xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
16 
17 #include <sys/mman.h>
18 
19 #include <algorithm>
20 #include <functional>
21 #include <initializer_list>
22 #include <memory>
23 
24 #include <gtest/gtest.h>
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/context_util.h"
27 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h"
28 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate_plugin.h"
29 #include "tensorflow/lite/interpreter.h"
30 #include "tensorflow/lite/kernels/test_util.h"
31 #include "tensorflow/lite/model.h"
32 #include "tensorflow/lite/nnapi/NeuralNetworksTypes.h"
33 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
34 
35 namespace tflite {
36 namespace {
37 
38 using ::testing::ElementsAre;
39 using ::testing::ElementsAreArray;
40 using ::testing::FloatNear;
41 using ::testing::Matcher;
42 
43 // TODO(b/110368244): figure out how to share the existing tests in kernels/ but
44 // with the delegation on. Also, add more unit tests to improve code coverage.
45 
46 // This matcher uses 1 as maximum tolerance.
47 MATCHER(QuantizedNear, "") {
48   const int diff = abs(std::get<0>(arg) - std::get<1>(arg));
49   if (diff > 1) {
50     *result_listener << "Quantized values can be at most off by one: " << diff;
51     return false;
52   }
53   return true;
54 }
55 
NnapiArrayFloatNear(const std::vector<float> & values,bool relaxed=false)56 auto NnapiArrayFloatNear(const std::vector<float>& values,
57                          bool relaxed = false) {
58   // Uses the same tolerance as NNAPI generated tests.
59   const float atol = relaxed ? 5 * 0.0009765625f : 1e-5f;
60   const float rtol = relaxed ? 5 * 0.0009765625f : 5 * 1.1920928955078125e-7f;
61 
62   std::vector<Matcher<float>> matchers;
63   matchers.reserve(values.size());
64   for (const float& v : values) {
65     const float tolerance = atol + rtol * std::abs(v);
66     matchers.emplace_back(FloatNear(v, tolerance));
67   }
68   return ElementsAreArray(matchers);
69 }
70 
// Test harness that wires a single-op TFLite model to the stateful NNAPI
// delegate. The NNAPI CPU implementation is always allowed so tests can run
// on hosts without a hardware accelerator.
class SingleOpModelWithNNAPI : public SingleOpModel {
 public:
  SingleOpModelWithNNAPI() { options_.disallow_nnapi_cpu = false; }

  // Constructs the harness with caller-provided delegate options; the
  // NNAPI-CPU fallback is force-enabled regardless of the passed value.
  explicit SingleOpModelWithNNAPI(
      const StatefulNnApiDelegate::Options& options) {
    options_ = options;
    options_.disallow_nnapi_cpu = false;
  }

  // Forwards a resize request for an input tensor to the interpreter.
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims) {
    return interpreter_->ResizeInputTensor(tensor_index, dims);
  }

  // Returns the delegate; valid only after ApplyNNAPIDelegate() has run.
  StatefulNnApiDelegate* GetDelegate() { return stateful_delegate_.get(); }

  // Associates an NNAPI buffer handle with the tensor at `index`.
  void SetBufferHandle(int index, TfLiteBufferHandle handle) {
    interpreter_->SetBufferHandle(index, handle, stateful_delegate_.get());
  }

  // Marks the tensor's CPU copy as stale so the delegate re-reads the
  // registered buffer on the next invocation.
  void MarkInputTensorDataStale(int index) {
    interpreter_->tensor(index)->data_is_stale = true;
  }

  TfLiteStatus AllocateTensors() { return interpreter_->AllocateTensors(); }

  // Records a max-size hint the delegate uses when sizing dynamic tensors.
  // Must be called before ApplyNNAPIDelegate() to take effect.
  void SetTensorMaxSize(uint32_t tensor_index, size_t max_size) {
    options_.tensor_max_size_hints.emplace(tensor_index, max_size);
  }

  // Creates the stateful delegate from the accumulated options and applies
  // it to the interpreter.
  void ApplyNNAPIDelegate() {
    stateful_delegate_ = std::make_unique<StatefulNnApiDelegate>(options_);
    SetDelegate(stateful_delegate_.get());
    ApplyDelegate();
  }

 protected:
  // Populates the tensor at `index` from float data, quantizing first when
  // `type` is an integral tensor type.
  void SetData(int index, TensorType type, const std::vector<float>& data) {
    switch (type) {
      case TensorType_FLOAT32:
        PopulateTensor(index, data);
        break;
      case TensorType_INT32:
        QuantizeAndPopulate<int32_t>(index, data);
        break;
      case TensorType_UINT8:
        QuantizeAndPopulate<uint8_t>(index, data);
        break;
      case TensorType_INT8:
        QuantizeAndPopulate<int8_t>(index, data);
        break;
      default:
        FAIL() << "Type not supported: " << type;
        break;
    }
  }

  // Reads the tensor at `index` back as floats, dequantizing UINT8 data with
  // the tensor's scale/zero-point.
  void GetData(int index, TensorType type, std::vector<float>* output) {
    switch (type) {
      case TensorType_FLOAT32:
        *output = ExtractVector<float>(index);
        break;
      case TensorType_UINT8:
        *output = Dequantize<uint8_t>(ExtractVector<uint8_t>(index),
                                      GetScale(index), GetZeroPoint(index));
        break;
      default:
        FAIL() << "Type not supported: " << type;
        break;
    }
  }

  void BuildInterpreterWithNNAPI(std::vector<std::vector<int>> input_shapes,
                                 bool allow_fp32_relax_to_fp16 = false,
                                 bool apply_delegate = true) {
    // We skip those TfLite delegates that are applied by default in TfLite
    // runtime by setting 'apply_delegate' to false. Afterwards, we explicitly
    // call ApplyDelegate to apply the NNAPI delegate to meet the testing
    // purpose.
    BuildInterpreter(input_shapes, /*num_threads=*/-1, allow_fp32_relax_to_fp16,
                     /*apply_delegate=*/false, /*allocate_and_delegate=*/true);
    if (apply_delegate) {
      ApplyNNAPIDelegate();
    }
  }

 private:
  // Options used to construct the stateful delegate.
  StatefulNnApiDelegate::Options options_;
  // Stateful NNAPI delegate. This is valid only if the state-ful constructor
  // is used, and only after ApplyNNAPIDelegate() has been called.
  std::unique_ptr<StatefulNnApiDelegate> stateful_delegate_;
};
164 
// Model with a single float ADD op, used by most delegate sanity tests.
class FloatAddOpModel : public SingleOpModelWithNNAPI {
 public:
  // Builds an ADD model using default delegate options.
  FloatAddOpModel(const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type,
                  bool allow_fp32_relax_to_fp16 = false) {
    Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16);
  }

  // Builds an ADD model with caller-provided stateful-delegate options.
  FloatAddOpModel(const StatefulNnApiDelegate::Options& options,
                  const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type,
                  bool allow_fp32_relax_to_fp16 = false)
      : SingleOpModelWithNNAPI(options) {
    Init(input1, input2, output, activation_type, allow_fp32_relax_to_fp16);
  }

  // Tensor indices of the two inputs.
  int input1() { return input1_; }
  int input2() { return input2_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input1_;
  int input2_;
  int output_;

 private:
  // Performs initialization logic shared across all constructors.
  void Init(const TensorData& input1, const TensorData& input2,
            const TensorData& output, ActivationFunctionType activation_type,
            bool allow_fp32_relax_to_fp16 = false) {
    input1_ = AddInput(input1);
    input2_ = AddInput(input2);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
                 CreateAddOptions(builder_, activation_type).Union());
    BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)},
                              allow_fp32_relax_to_fp16);
  }
};
207 
208 // Do a test with the NN API using no activation.
TEST(NNAPIDelegate,AddWithNoActivation)209 TEST(NNAPIDelegate, AddWithNoActivation) {
210   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
211                     {TensorType_FLOAT32, {1, 2, 2, 1}},
212                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
213   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
214   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
215   ASSERT_EQ(m.Invoke(), kTfLiteOk);
216   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
217 }
218 
219 // Do a test with scalar input using no activation.
TEST(NNAPIDelegate,AddScalarWithNoActivation)220 TEST(NNAPIDelegate, AddScalarWithNoActivation) {
221   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
222                     {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
223                     ActivationFunctionType_NONE);
224   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.7});
225   m.PopulateTensor<float>(m.input2(), {0.1});
226   ASSERT_EQ(m.Invoke(), kTfLiteOk);
227   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.3, 0.8, 0.8}));
228 }
229 
230 // Do a test with the NN API using no activation.
231 // The test allows computing FP32 with FP16 precision. In this particular case,
232 // calculating in FP32 or FP16 should produce the same results.
TEST(NNAPIDelegate,AddWithNoActivationRelaxed)233 TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
234   FloatAddOpModel m(
235       {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
236       {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, true);
237   m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
238   m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
239   ASSERT_EQ(m.Invoke(), kTfLiteOk);
240   EXPECT_THAT(m.GetOutput(),
241               NnapiArrayFloatNear({-1.0, 1.0, 4.0, 6.0}, /*relaxed=*/true));
242 }
243 
244 // Do a test with the NN api with relu.
TEST(NNAPIDelegate,AddWithRelu)245 TEST(NNAPIDelegate, AddWithRelu) {
246   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
247                     {TensorType_FLOAT32, {1, 2, 2, 1}},
248                     {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU);
249   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
250   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
251   ASSERT_EQ(m.Invoke(), kTfLiteOk);
252   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({0.0, 0.4, 1.0, 1.3}));
253 }
254 
255 // Verify that resize attempts succeed.
TEST(NNAPIDelegate,ResizeInputTensorsWorks)256 TEST(NNAPIDelegate, ResizeInputTensorsWorks) {
257   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
258                     {TensorType_FLOAT32, {1, 2, 2, 1}},
259                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
260 
261   EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 2, 1}), kTfLiteOk);
262   EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 3, 2, 1}), kTfLiteOk);
263   EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
264   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 0.9, 0.7});
265   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 0.2, 0.8});
266   ASSERT_EQ(m.Invoke(), kTfLiteOk);
267   EXPECT_THAT(m.GetOutput(),
268               NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3, 1.1, 1.5}));
269 
270   EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 2, 2, 1}), kTfLiteOk);
271   EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 2, 2, 1}), kTfLiteOk);
272   EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
273   m.PopulateTensor<float>(m.input1(), {0.7, 0.8, 0.9, 0.7});
274   m.PopulateTensor<float>(m.input2(), {0.3, 0.5, 0.2, 0.8});
275   ASSERT_EQ(m.Invoke(), kTfLiteOk);
276   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({1.0, 1.3, 1.1, 1.5}));
277 }
278 
TEST(NNAPIDelegate,ResizeDynamicBatchInputTensorsWorks)279 TEST(NNAPIDelegate, ResizeDynamicBatchInputTensorsWorks) {
280   StatefulNnApiDelegate::Options options;
281   options.allow_dynamic_dimensions = true;
282   options.max_execution_cache_size = 1;
283 
284   FloatAddOpModel m(options,
285                     {TensorType_FLOAT32, /*shape=*/{1, 3, 2, 1}, /*min=*/0.0f,
286                      /*max=*/0.0f, /*scale=*/0.0f,
287                      /*zero_point=*/0, /*per_channel_quantization=*/false,
288                      /*per_channel_quantization_scales=*/{},
289                      /*per_channel_quantization_offsets=*/{},
290                      /*channel_index=*/0, /*traversal_order=*/{},
291                      /*format=*/{},
292                      /*block_size=*/{}, /*block_map=*/{},
293                      /*shape_signature=*/{1, -1, 2, 1}},
294                     {TensorType_FLOAT32, /*shape=*/{1, 3, 2, 1}, /*min=*/0.0f,
295                      /*max=*/0.0f, /*scale=*/0.0f,
296                      /*zero_point=*/0, /*per_channel_quantization=*/false,
297                      /*per_channel_quantization_scales=*/{},
298                      /*per_channel_quantization_offsets=*/{},
299                      /*channel_index=*/0, /*traversal_order=*/{},
300                      /*format=*/{},
301                      /*block_size=*/{}, /*block_map=*/{},
302                      /*shape_signature=*/{1, -1, 2, 1}},
303                     {TensorType_FLOAT32, /*shape=*/{}, /*min=*/0.0f,
304                      /*max=*/0.0f, /*scale=*/0.0f,
305                      /*zero_point=*/0, /*per_channel_quantization=*/false,
306                      /*per_channel_quantization_scales=*/{},
307                      /*per_channel_quantization_offsets=*/{},
308                      /*channel_index=*/0, /*traversal_order=*/{},
309                      /*format=*/{},
310                      /*block_size=*/{}, /*block_map=*/{},
311                      /*shape_signature=*/{1, -1, 2, 1}},
312                     ActivationFunctionType_NONE);
313 
314   // Define 2 test cases, each with a different dynamic dimension value.
315   auto RunTestCase1 = [&m]() {
316     EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 3, 2, 1}), kTfLiteOk);
317     EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 3, 2, 1}), kTfLiteOk);
318     EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
319     m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8, 0.9, 0.7});
320     m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5, 0.2, 0.8});
321     ASSERT_EQ(m.Invoke(), kTfLiteOk);
322     EXPECT_THAT(m.GetOutput(),
323                 ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 1.1, 1.5}));
324   };
325   auto RunTestCase2 = [&m]() {
326     EXPECT_EQ(m.ResizeInputTensor(m.input1(), {1, 2, 2, 1}), kTfLiteOk);
327     EXPECT_EQ(m.ResizeInputTensor(m.input2(), {1, 2, 2, 1}), kTfLiteOk);
328     EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
329     m.PopulateTensor<float>(m.input1(), {0.7, 0.8, 0.9, 0.7});
330     m.PopulateTensor<float>(m.input2(), {0.3, 0.5, 0.2, 0.8});
331     ASSERT_EQ(m.Invoke(), kTfLiteOk);
332     EXPECT_THAT(m.GetOutput(), ElementsAreArray({1.0, 1.3, 1.1, 1.5}));
333   };
334 
335   // TODO(b/221070667): Find a way to test whether the execution has indeed been
336   // reused or not.
337   // This will create a new execution for case 1.
338   RunTestCase1();
339   // This will reuse the execution for case 1.
340   RunTestCase1();
341   // This will destroy case 1, and create a new execution for case 2.
342   RunTestCase2();
343   // This will destroy case 2, and create a new execution for case 1.
344   RunTestCase1();
345 }
346 
347 // Sanity check for the state-ful NNAPI delegate.
TEST(NNAPIDelegate,StatefulDelegate)348 TEST(NNAPIDelegate, StatefulDelegate) {
349   StatefulNnApiDelegate::Options options;
350   options.execution_preference =
351       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
352 
353   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
354                     {TensorType_FLOAT32, {1, 2, 2, 1}},
355                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
356   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
357   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
358   ASSERT_EQ(m.Invoke(), kTfLiteOk);
359   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
360 }
361 
362 // Sanity check for the state-ful NNAPI delegate with accelerator_name
363 // specified.
TEST(NNAPIDelegate,StatefulDelegateWithAcceleratorName)364 TEST(NNAPIDelegate, StatefulDelegateWithAcceleratorName) {
365   StatefulNnApiDelegate::Options options;
366   options.execution_preference =
367       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
368   options.accelerator_name = "nnapi-reference";
369 
370   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
371                     {TensorType_FLOAT32, {1, 2, 2, 1}},
372                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
373   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
374   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
375   ASSERT_EQ(m.Invoke(), kTfLiteOk);
376   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
377 }
378 
379 // Sanity check for the state-ful NNAPI delegate with invalid accelerator_name
380 // specified.
TEST(NNAPIDelegate,StatefulDelegateWithInvalidAcceleratorName)381 TEST(NNAPIDelegate, StatefulDelegateWithInvalidAcceleratorName) {
382   if (!NnApiImplementation()->ANeuralNetworksDevice_getName) {
383     GTEST_SKIP();
384   }
385   testing::internal::CaptureStderr();
386   StatefulNnApiDelegate::Options options;
387   options.execution_preference =
388       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
389   options.accelerator_name = "foo";
390 
391   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
392                     {TensorType_FLOAT32, {1, 2, 2, 1}},
393                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
394   EXPECT_THAT(testing::internal::GetCapturedStderr(),
395               testing::HasSubstr(
396                   "Could not find the specified NNAPI accelerator: foo"));
397 
398   // Execution should fall back to the default CPU path.
399   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
400   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
401   ASSERT_EQ(m.Invoke(), kTfLiteOk);
402   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
403 }
404 
405 // Sanity check for the state-ful NNAPI delegate with compilation caching
406 // enabled.
TEST(NNAPIDelegate,StatefulDelegateWithCompilationCaching)407 TEST(NNAPIDelegate, StatefulDelegateWithCompilationCaching) {
408   StatefulNnApiDelegate::Options options;
409   options.execution_preference =
410       StatefulNnApiDelegate::Options::ExecutionPreference::kLowPower;
411   options.cache_dir = "/data/local/tmp";
412   options.model_token = "NNAPIDelegate.StatefulDelegateWithCompilationCaching";
413 
414   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
415                     {TensorType_FLOAT32, {1, 2, 2, 1}},
416                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
417   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
418   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
419   ASSERT_EQ(m.Invoke(), kTfLiteOk);
420   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
421 }
422 
423 // Sanity check for the state-ful NNAPI delegate with QoS hints.
TEST(NNAPIDelegate,StatefulDelegateWithQoS)424 TEST(NNAPIDelegate, StatefulDelegateWithQoS) {
425   StatefulNnApiDelegate::Options options;
426   options.accelerator_name = "nnapi-reference";
427   options.execution_priority = ANEURALNETWORKS_PRIORITY_HIGH;
428   options.max_compilation_timeout_duration_ns = UINT64_MAX;
429   options.max_execution_timeout_duration_ns = UINT64_MAX;
430   options.max_execution_loop_timeout_duration_ns = UINT64_MAX;
431 
432   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
433                     {TensorType_FLOAT32, {1, 2, 2, 1}},
434                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
435   m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
436   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
437   ASSERT_EQ(m.Invoke(), kTfLiteOk);
438   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
439 }
440 
441 // Sanity check for the state-ful NNAPI delegate using TfLiteBufferHandle.
TEST(NNAPIDelegate,StatefulDelegateWithBufferHandles)442 TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) {
443   // Skip the test if Android specific functions could not be found.
444   if (!NnApiImplementation()->ASharedMemory_create ||
445       !NnApiImplementation()->ANeuralNetworksMemory_createFromFd) {
446     GTEST_SKIP();
447   }
448 
449   StatefulNnApiDelegate::Options options;
450   // Allow NNAPI CPU fallback path.
451   options.disallow_nnapi_cpu = false;
452   options.max_execution_cache_size = 2;
453   FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
454                     {TensorType_FLOAT32, {1, 2, 2, 1}},
455                     {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
456   auto* delegate = m.GetDelegate();
457   // Create ASharedMemory and copy data into it.
458   constexpr auto kInput1ByteSize = 4 * sizeof(float);
459   ANeuralNetworksMemory* input1_memory = nullptr;
460   int fd =
461       NnApiImplementation()->ASharedMemory_create("input1", kInput1ByteSize);
462   EXPECT_GE(fd, 0);
463   void* input1_memory_data =
464       mmap(nullptr, kInput1ByteSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
465   EXPECT_TRUE(input1_memory_data != nullptr);
466   float input1_data[] = {-2.0, 0.2, 0.7, 0.8};
467   memcpy(input1_memory_data, input1_data, kInput1ByteSize);
468   int result = NnApiImplementation()->ANeuralNetworksMemory_createFromFd(
469       kInput1ByteSize, PROT_READ, fd, 0, &input1_memory);
470   EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
471   ASSERT_NE(input1_memory, nullptr);
472 
473   struct DummyMemoryContext {
474     ANeuralNetworksMemory* memory_handle;
475     void* memory_data;
476     size_t byte_size;
477   };
478   DummyMemoryContext memory_context = {input1_memory, input1_memory_data,
479                                        kInput1ByteSize};
480   static StatefulNnApiDelegate::CopyToHostTensorFnPtr memory_callback =
481       [](TfLiteTensor* tensor, ANeuralNetworksMemory* memory,
482          size_t memory_offset, size_t byte_size,
483          void* callback_context) -> TfLiteStatus {
484     auto memory_context =
485         reinterpret_cast<DummyMemoryContext*>(callback_context);
486     if (memory != memory_context->memory_handle ||
487         memory_offset + byte_size > memory_context->byte_size) {
488       return kTfLiteError;
489     }
490     memcpy(
491         tensor->data.raw,
492         reinterpret_cast<uint8_t*>(memory_context->memory_data) + memory_offset,
493         byte_size);
494     return kTfLiteOk;
495   };
496   auto input1_handle = delegate->RegisterNnapiMemory(
497       input1_memory, memory_callback, &memory_context);
498   m.SetBufferHandle(m.input1(), input1_handle);
499   m.MarkInputTensorDataStale(m.input1());
500   m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
501   ASSERT_EQ(m.Invoke(), kTfLiteOk);
502   EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9, 0.4, 1.0, 1.3}));
503 
504   // Run the inference multiple times with the same buffer so that the execution
505   // can be reused.
506   for (int i = 0; i < 10; i++) {
507     // Change the value a little bit.
508     input1_data[0] = -2.0 + i;
509     memcpy(input1_memory_data, input1_data, kInput1ByteSize);
510     m.MarkInputTensorDataStale(m.input1());
511     ASSERT_EQ(m.Invoke(), kTfLiteOk);
512     EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-1.9f + i, 0.4f, 1.0f, 1.3f}));
513   }
514 
515   // Run the inference multiple times and each time register a buffer.
516   // Each will destroy the previous cache and create a new execution.
517   for (int i = 0; i < 10; i++) {
518     // Change the value a little bit.
519     input1_data[0] = -2.0 + i;
520     memcpy(input1_memory_data, input1_data, kInput1ByteSize);
521     auto input1_handle = delegate->RegisterNnapiMemory(
522         input1_memory, memory_callback, &memory_context);
523     m.SetBufferHandle(m.input1(), input1_handle);
524     m.MarkInputTensorDataStale(m.input1());
525     m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
526     ASSERT_EQ(m.Invoke(), kTfLiteOk);
527     EXPECT_THAT(m.GetOutput(),
528                 NnapiArrayFloatNear({-1.9f + i, 0.4f, 1.0f, 1.3f}));
529   }
530 }
531 
// Model with a single float MUL op.
class FloatMulOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatMulOpModel(const TensorData& input1, const TensorData& input2,
                  const TensorData& output,
                  ActivationFunctionType activation_type) {
    input1_ = AddInput(input1);
    input2_ = AddInput(input2);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
                 CreateMulOptions(builder_, activation_type).Union());
    BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
  }

  // Tensor indices of the two inputs.
  int input1() { return input1_; }
  int input2() { return input2_; }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input1_;
  int input2_;
  int output_;
};
555 
// MUL through NNAPI with no fused activation.
TEST(NNAPIDelegate, MulWithNoActivation) {
  FloatMulOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {1, 2, 2, 1}},
                        {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.5});
  model.PopulateTensor<float>(model.input1(), {-2.0, 0.2, 0.7, 0.8});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({-0.2, 0.04, 0.21, 0.4}));
}
565 
// Model with a single 2-D pooling op (`type` selects average/max/L2 pool),
// using VALID padding, stride 2x2, and no fused activation.
class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
 public:
  FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
                      int filter_width, int filter_height,
                      const TensorData& output) {
    input_ = AddInput(input);
    output_ = AddOutput(output);

    SetBuiltinOp(
        type, BuiltinOptions_Pool2DOptions,
        CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
                            filter_height, ActivationFunctionType_NONE)
            .Union());

    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input_;
  int output_;
};
593 
// AVERAGE_POOL_2D with a 2x2 filter over a 1x2x4x1 input.
TEST(NNAPIDelegate, AveragePoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_AVERAGE_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Each output is the mean of one 2x2 window.
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({2.75, 5.75}));
}
606 
// MAX_POOL_2D with a 2x2 filter over a 1x2x4x1 input.
TEST(NNAPIDelegate, MaxPoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_MAX_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Each output is the maximum of one 2x2 window.
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({6, 10}));
}
619 
// L2_POOL_2D with a 2x2 filter over a 1x2x4x1 input.
TEST(NNAPIDelegate, L2PoolWithNoActivation) {
  FloatPoolingOpModel model(BuiltinOperator_L2_POOL_2D,
                            /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                            /*filter_width=*/2, /*filter_height=*/2,
                            /*output=*/{TensorType_FLOAT32, {}});
  model.SetInput({
      0, 6, 2, 4,   //
      3, 2, 10, 7,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Each output is the RMS of one 2x2 window.
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({3.5, 6.5}));
}
632 
// Model with a single CONV_2D op. Supports both float and quantized (UINT8)
// variants; the variant is chosen by the input tensor's type.
class ConvolutionOpModel : public SingleOpModelWithNNAPI {
 public:
  ConvolutionOpModel(
      const TensorData& input, const TensorData& filter,
      const TensorData& output, int stride_width = 2, int stride_height = 2,
      enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
      int dilation_width_factor = 1, int dilation_height_factor = 1)
      : input_type_(input.type), filter_type_(filter.type) {
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    // Bias has one element per output channel (filter dim 0).
    int bias_size = GetShape(filter_)[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(
                     builder_, padding, stride_width, stride_height, activation,
                     dilation_width_factor, dilation_height_factor)
                     .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  void SetFilter(std::initializer_list<float> data) {
    SetData(filter_, filter_type_, data);
  }

  // Bias is float for float models and INT32 for quantized models.
  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output as floats, dequantizing for quantized models.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

  // Returns the raw quantized output; empty for float models.
  std::vector<uint8_t> GetQuantizedOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return {};  // Not supported.
    } else {
      return ExtractVector<uint8_t>(output_);
    }
  }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType filter_type_;
};
709 
// In this test we set the input and output scales so that the results
// match exactly the 'non-quantized' version.
TEST(ConvolutionOpTest, SimpleTestQuantized) {
  ConvolutionOpModel model({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
                           {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
                           {TensorType_UINT8, {}, -127, 128});
  model.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // Dequantized values must match the float reference.
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         18, 2, 5,  // first batch, left
                                         18, 2, 5,  // first batch, right
                                         17, 4, 3,  // second batch, left
                                         37, 4, 3,  // second batch, right
                                     },
                                     1e-5)));
  // For good measure, also verify the raw quantized values.
  EXPECT_THAT(model.GetQuantizedOutput(), ElementsAreArray({
                                              145, 129, 132,  //
                                              145, 129, 132,  //
                                              144, 131, 130,  //
                                              164, 131, 130,  //
                                          }));
}
749 
TEST(ConvolutionOpTest, SimpleTestQuantizedGrouped) {
  // The filter's input-channel dimension (1) is half the input's (2), making
  // this a grouped convolution.
  ConvolutionOpModel model({TensorType_UINT8, {2, 2, 2, 2}, -63.5, 64},
                           {TensorType_UINT8, {2, 2, 2, 1}, -63.5, 64},
                           {TensorType_UINT8, {}, -127, 128});
  model.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
  });
  model.SetBias({1, 2});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         18, 2,  // first batch
                                         23, 6   // second batch
                                     },
                                     1e-5)));
  // For good measure, also verify the raw quantized values.
  EXPECT_THAT(model.GetQuantizedOutput(), ElementsAreArray({
                                              145, 129,  //
                                              150, 133,  //
                                          }));
}
782 
TEST(ConvolutionOpTest, FloatInputQuantizedWeights) {
  // Hybrid case: float input/output with a uint8-quantized filter, so results
  // are only approximate (tolerance 0.2 below).
  ConvolutionOpModel model({TensorType_FLOAT32, {2, 2, 4, 1}},
                           {TensorType_UINT8, {3, 2, 2, 1}, 0, 64},
                           {TensorType_FLOAT32, {}});
  model.SetInput({
      // First batch
      1, 1, 1, 2,  // row = 1
      2, 2, 2, 1,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,  // first 2x2 filter
      0, 1, 0, 1,  // second 2x2 filter
      0, 0, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         18, 5, 7,    // first batch, left
                                         16, 5, 6,    // first batch, right
                                         17, 6, 6,    // second batch, left
                                         37, 10, 10,  // second batch, right
                                     },
                                     0.2)));
}
813 
TEST(ConvolutionOpTest, NoActivation) {
  // Float reference for the quantized convolution tests above.
  ConvolutionOpModel model({TensorType_FLOAT32, {2, 2, 4, 1}},
                           {TensorType_FLOAT32, {3, 2, 2, 1}},
                           {TensorType_FLOAT32, {}});

  model.SetInput({
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  });
  model.SetFilter({
      1, 2, 3, 4,    // first 2x2 filter
      -1, 1, -1, 1,  // second 2x2 filter
      -1, -1, 1, 1,  // third 2x2 filter
  });
  model.SetBias({1, 2, 3});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     18, 2, 5,  // first batch, left
                                     18, 2, 5,  // first batch, right
                                     17, 4, 3,  // second batch, left
                                     37, 4, 3,  // second batch, right
                                 }));
}
843 
TEST(ConvolutionOpTest, SimpleTestQuantizedOutputMultiplierGreaterThan1) {
  // output_multiplier = 1.0118
  ConvolutionOpModel quantized({TensorType_UINT8, {2, 2, 4, 1}, -128.5, 128},
                               {TensorType_UINT8, {3, 2, 2, 1}, -128.5, 128},
                               {TensorType_UINT8, {}, -127, 128});
  ConvolutionOpModel reference({TensorType_FLOAT32, {2, 2, 4, 1}},
                               {TensorType_FLOAT32, {3, 2, 2, 1}},
                               {TensorType_FLOAT32, {}});
  std::initializer_list<float> input = {
      // First batch
      1, 1, 1, 1,  // row = 1
      2, 2, 2, 2,  // row = 2
      // Second batch
      1, 2, 3, 4,  // row = 1
      1, 2, 3, 4,  // row = 2
  };
  std::initializer_list<float> filter = {
      1,  2,  3,  4,  // first 2x2 filter
      -1, 1,  -1, 1,  // second 2x2 filter
      -1, -1, 1,  1,  // third 2x2 filter
  };
  std::initializer_list<float> bias = {1, 2, 3};

  // Feed both models the same data and run them.
  auto fill_and_run = [&](ConvolutionOpModel& op) {
    op.SetInput(input);
    op.SetFilter(filter);
    op.SetBias(bias);
    ASSERT_EQ(op.Invoke(), kTfLiteOk);
  };
  fill_and_run(quantized);
  fill_and_run(reference);

  // The quantized result must track the float reference within one unit.
  EXPECT_THAT(quantized.GetOutput(),
              ElementsAreArray(ArrayFloatNear(reference.GetOutput(), 1)));
}
880 
TEST(ConvolutionOpTest, SimpleTestFloatWithDilation) {
  const int depth = 1;
  const int image_width = 9;
  const int image_height = 9;
  const int image_batch_count = 1;
  const int filter_size = 3;
  const int filter_count = 1;
  const int stride_width = 1;
  const int stride_height = 1;
  const int dilation_width_factor = 3;
  const int dilation_height_factor = 3;
  const Padding padding = Padding_VALID;
  ConvolutionOpModel model(
      {TensorType_FLOAT32,
       {image_batch_count, image_height, image_width, depth}},
      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
      {TensorType_FLOAT32, {}}, stride_width, stride_height, padding,
      ActivationFunctionType_NONE, dilation_width_factor,
      dilation_height_factor);

  // 9x9 image: all zeros except a centered 3x3 block of ones.
  // clang-format off
  model.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 1, 1, 1, 0, 0, 0,
                  0, 0, 0, 1, 1, 1, 0, 0, 0,
                  0, 0, 0, 1, 1, 1, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 0});
  // clang-format on
  // 3x3 filter holding 1..9 in row-major order:
  // | 1 | 2 | 3 |
  // | 4 | 5 | 6 |
  // | 7 | 8 | 9 |
  model.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Zero bias for this test.
  model.SetBias({0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // With dilation rate 3, exactly one filter tap overlaps the ones block at
  // every valid position, so the 3x3 output is all 5s.
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
938 
// Variant of ConvolutionOpModel whose setters always quantize (uint8 for
// input/filter, int32 for bias) and that exposes the raw quantized output.
class QuantizedConvolutionOpModel : public ConvolutionOpModel {
 public:
  using ConvolutionOpModel::ConvolutionOpModel;

  // Quantizes 'data' into the uint8 input tensor.
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(input_, data);
  }

  // Quantizes 'data' into the uint8 filter tensor.
  void SetFilter(std::initializer_list<float> data) {
    QuantizeAndPopulate<uint8_t>(filter_, data);
  }

  // Quantizes 'data' into the int32 bias tensor.
  void SetBias(std::initializer_list<float> data) {
    QuantizeAndPopulate<int32_t>(bias_, data);
  }

  // Raw quantized output values (shadows the base-class float GetOutput).
  std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
  // Output dequantized using the output tensor's scale and zero point.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }
};
961 
TEST(ConvolutionOpTest,SimpleTestQuantizedWithDilation)962 TEST(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
963   const int depth = 1;
964   const int image_width = 9;
965   const int image_height = 9;
966   const int image_batch_count = 1;
967   const int filter_size = 3;
968   const int filter_count = 1;
969   const int stride_width = 1;
970   const int stride_height = 1;
971   const int dilation_width_factor = 3;
972   const int dilation_height_factor = 3;
973   const Padding padding = Padding_VALID;
974   ConvolutionOpModel m({TensorType_UINT8,
975                         {image_batch_count, image_height, image_width, depth},
976                         0,
977                         127.5},
978                        {TensorType_UINT8,
979                         {depth, filter_size, filter_size, filter_count},
980                         0,
981                         127.5},
982                        {TensorType_UINT8, {}, 0, 255}, stride_width,
983                        stride_height, padding, ActivationFunctionType_NONE,
984                        dilation_width_factor, dilation_height_factor);
985 
986   // The image matrix is:
987   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
988   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
989   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
990   // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
991   // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
992   // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
993   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
994   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
995   // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
996   // clang-format off
997   m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
998               0, 0, 0, 0, 0, 0, 0, 0, 0,
999               0, 0, 0, 0, 0, 0, 0, 0, 0,
1000               0, 0, 0, 1, 1, 1, 0, 0, 0,
1001               0, 0, 0, 1, 1, 1, 0, 0, 0,
1002               0, 0, 0, 1, 1, 1, 0, 0, 0,
1003               0, 0, 0, 0, 0, 0, 0, 0, 0,
1004               0, 0, 0, 0, 0, 0, 0, 0, 0,
1005               0, 0, 0, 0, 0, 0, 0, 0, 0});
1006   // clang-format on
1007   // The filter matrix is:
1008   // | 1 | 2 | 3 |
1009   // | 4 | 5 | 6 |
1010   // | 7 | 8 | 9 |
1011   m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
1012   // Zero bias for this test.
1013   m.SetBias({0});
1014   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1015 
1016   // Since the dilation rate is 3 this will reduce the size of the output from
1017   // 10x10 to 3x3 of all 5s. Specifically:
1018   // | 5 | 5 | 5 |
1019   // | 5 | 5 | 5 |
1020   // | 5 | 5 | 5 |
1021   EXPECT_THAT(m.GetQuantizedOutput(),
1022               ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
1023 }
1024 
// Builds a CONV_2D with a per-channel-quantized constant int8 filter and a
// matching per-channel-scaled constant int32 bias.
class PerChannelQuantizedConvolutionWithConstantFilterOpModel
    : public SingleOpModelWithNNAPI {
 public:
  PerChannelQuantizedConvolutionWithConstantFilterOpModel(
      const TensorData& input, const TensorData& filter,
      std::initializer_list<int8_t> filter_data,
      std::initializer_list<int32_t> bias_data, const TensorData& output,
      int stride_width = 2, int stride_height = 2,
      enum Padding padding = Padding_VALID,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE,
      int dilation_width_factor = 1, int dilation_height_factor = 1)
      : input_type_(input.type), filter_type_(filter.type) {
    CHECK(filter.per_channel_quantization);
    input_ = AddInput(input);
    filter_ = AddConstInput(filter, filter_data);

    // One bias element per output channel; each channel's bias scale is
    // input_scale * filter_scale[i] with zero offset.
    const int bias_size = GetShape(filter_)[0];
    const int num_channels = filter.per_channel_quantization_scales.size();
    const std::vector<int64_t> bias_offsets(num_channels, 0);
    std::vector<float> bias_scales(num_channels);
    for (int i = 0; i < num_channels; i++) {
      bias_scales[i] = input.scale * filter.per_channel_quantization_scales[i];
    }
    const TensorData bias{TensorType_INT32,
                          {bias_size},
                          /*min=*/0,
                          /*max=*/0,
                          /*scale=*/0,
                          /*zero_point=*/0,
                          /*per_channel_quantization=*/true,
                          /*per_channel_quantization_scales=*/bias_scales,
                          /*per_channel_quantization_offsets=*/bias_offsets,
                          /*channel_index=*/0};
    bias_ = AddConstInput(bias, bias_data);

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
                 CreateConv2DOptions(
                     builder_, padding, stride_width, stride_height, activation,
                     dilation_width_factor, dilation_height_factor)
                     .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  // Quantizes 'data' into the int8 input tensor.
  void SetInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<int8_t>(input_, data);
  }

  // Raw int8 output values.
  std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType filter_type_;
};
1087 
// Runs a 1x2x3x2 int8 input through a two-output-channel per-channel-quantized
// convolution (channel scales 1 and 2) and checks the raw int8 output within
// one quantization unit.
TEST(ConvolutionOpTest, SimplePerChannelTest) {
  PerChannelQuantizedConvolutionWithConstantFilterOpModel m(
      {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1},
      {TensorType_INT8,
       // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
       {2, 2, 2, 2},
       /*min=*/0,
       /*max=*/0,
       /*scale=*/0,
       /*zero_point=*/0,
       /*per_channel_quantization=*/true,
       /*per_channel_quantization_scales=*/{1, 2},
       /*per_channel_quantization_offsets=*/{0, 0},
       /*channel_index=*/0},
      /*filter_data=*/
      {
          // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
          1, 2,  // out channel = 0, y = 0, x = 0
          3, 4,  // out channel = 0, y = 0, x = 1
          3, 4,  // out channel = 0, y = 1, x = 0
          5, 6,  // out channel = 0, y = 1, x = 1
          4, 4,  // out channel = 1, y = 0, x = 0
          3, 3,  // out channel = 1, y = 0, x = 1
          2, 2,  // out channel = 1, y = 1, x = 0
          1, 1,  // out channel = 1, y = 1, x = 1
      },
      /*bias_data=*/{6, -2}, {TensorType_INT8, {}, -63.5, 64, 0.5, -1},
      /*stride_width=*/1, /*stride_height=*/1);
  m.SetInput({
      // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
      3, 2,    // batch = 0, y = 0, x = 0
      1, -1,   // batch = 0, y = 0, x = 1
      -2, -3,  // batch = 0, y = 0, x = 2
      4, 3,    // batch = 0, y = 1, x = 0
      2, -2,   // batch = 0, y = 1, x = 1
      -3, -4,  // batch = 0, y = 1, x = 2
  });

  // Invoke and verify output.
  // output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              testing::Pointwise(QuantizedNear(), {61, 127, -115, -93}));
}
1132 
// Wraps a single DEPTHWISE_CONV_2D operator (VALID padding, stride 1, no
// activation) for NNAPI-delegate testing; float32 or uint8-quantized.
class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
 public:
  DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
                              const TensorData& output)
      : input_type_(input.type) {
    input_ = AddInput(input);
    filter_ = AddInput(filter);

    // One bias element per output channel (last filter dimension for
    // depthwise filters).
    int bias_size = GetShape(filter_)[3];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(filter_);
      TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    // depth_multiplier = output channels / input channels.
    int input_depth = GetShape(input_)[3];
    int output_depth = GetShape(filter_)[3];
    int depth_mul = output_depth / input_depth;

    SetBuiltinOp(
        BuiltinOperator_DEPTHWISE_CONV_2D,
        BuiltinOptions_DepthwiseConv2DOptions,
        CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
                                     ActivationFunctionType_NONE)
            .Union());

    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(filter_), GetShape(bias_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // NOTE(review): populates the filter using the *input* tensor's type;
  // assumes filter and input always share the same TensorType in these
  // tests — confirm before reusing with mixed-type models.
  void SetFilter(std::initializer_list<float> data) {
    SetData(filter_, input_type_, data);
  }

  // Quantized models use an int32 bias tensor.
  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output as floats, dequantizing for quantized models.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

 protected:
  int input_;
  int filter_;
  int bias_;
  int output_;

  const TensorType input_type_;
};
1201 
TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
  DepthwiseConvolutionOpModel model({TensorType_FLOAT32, {1, 3, 2, 2}},
                                    {TensorType_FLOAT32, {1, 2, 2, 4}},
                                    {TensorType_FLOAT32, {}});

  model.SetInput({
      1, 2, 7, 8,    // column 1
      3, 4, 9, 10,   // column 2
      5, 6, 11, 12,  // column 3
  });
  model.SetFilter({
      1, 2, 3, 4,        //
      -9, 10, -11, 12,   //
      5, 6, 7, 8,        //
      13, -14, 15, -16,  //
  });
  model.SetBias({1, 2, 3, 4});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     71, -34, 99, -20,  //
                                     91, -26, 127, -4,  //
                                 }));
}
1227 
TEST(QuantizedDepthwiseConv2DTest, FilterMultiplierGreaterThan1) {
  DepthwiseConvolutionOpModel quantized(
      {TensorType_UINT8, {1, 3, 2, 2}, -128.5, 128},
      {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
      {TensorType_UINT8, {}, -127, 128});
  DepthwiseConvolutionOpModel reference({TensorType_FLOAT32, {1, 3, 2, 2}},
                                        {TensorType_FLOAT32, {1, 2, 2, 4}},
                                        {TensorType_FLOAT32, {}});

  std::initializer_list<float> input = {
      1, 2, 7,  8,   // column 1
      3, 4, 9,  10,  // column 2
      5, 6, 11, 12,  // column 3
  };
  std::initializer_list<float> filter = {
      1,  2,   3,   4,    //
      -9, 10,  -11, 12,   //
      5,  6,   7,   8,    //
      13, -14, 15,  -16,  //
  };
  std::initializer_list<float> bias = {1, 2, 3, 4};

  // Feed both models the same data and run them.
  auto fill_and_run = [&](DepthwiseConvolutionOpModel& op) {
    op.SetInput(input);
    op.SetFilter(filter);
    op.SetBias(bias);
    ASSERT_EQ(op.Invoke(), kTfLiteOk);
  };
  fill_and_run(quantized);
  fill_and_run(reference);

  // The quantized result must track the float reference within one unit.
  EXPECT_THAT(quantized.GetOutput(),
              ElementsAreArray(ArrayFloatNear(reference.GetOutput(), 1)));
}
1263 
// Wraps a single FULLY_CONNECTED operator for NNAPI-delegate testing;
// supports float32 as well as uint8-quantized input/weights.
class FullyConnectedOpModel : public SingleOpModelWithNNAPI {
 public:
  FullyConnectedOpModel(
      const TensorData& input, const TensorData& weights,
      const TensorData& output,
      enum ActivationFunctionType activation = ActivationFunctionType_NONE)
      : input_type_(input.type), weights_type_(weights.type) {
    input_ = AddInput(input);
    weights_ = AddInput(weights);

    // One bias element per output unit (first weights dimension).
    const int units = weights.shape[0];
    if (input.type == TensorType_FLOAT32) {
      bias_ = AddInput({TensorType_FLOAT32, {units}});
    } else {
      // This is a quantized version. The scale of 'bias' depends on the scales
      // of input and filter. Supposedly this is correctly set during quantized
      // training.
      auto bias_scale = GetScale(input_) * GetScale(weights_);
      TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale};
      bias_ = AddInput(bias);
    }

    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
                 BuiltinOptions_FullyConnectedOptions,
                 CreateFullyConnectedOptions(builder_, activation).Union());
    BuildInterpreterWithNNAPI(
        {GetShape(input_), GetShape(weights_), GetShape(bias_)});
  }

  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // Populates the weights using the weights tensor's own type.
  void SetWeights(std::initializer_list<float> data) {
    SetData(weights_, weights_type_, data);
  }

  // Quantized models use an int32 bias tensor.
  void SetBias(std::initializer_list<float> data) {
    const auto bias_type =
        (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
    SetData(bias_, bias_type, data);
  }

  // Returns the output as floats, dequantizing for quantized models.
  std::vector<float> GetOutput() {
    if (input_type_ == TensorType_FLOAT32) {
      return ExtractVector<float>(output_);
    } else {
      return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                                 GetScale(output_), GetZeroPoint(output_));
    }
  }

 protected:
  int input_;
  int weights_;
  int bias_;
  int output_;

  const TensorType input_type_;
  const TensorType weights_type_;
};
1327 
TEST(FullyConnectedOpTest, SimpleTest) {
  FullyConnectedOpModel model(/*input=*/{TensorType_FLOAT32, {2, 10}},
                              /*weights=*/{TensorType_FLOAT32, {3, 10}},
                              /*output=*/{TensorType_FLOAT32});
  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
1348 
TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) {
  // Hybrid case: float input/output with uint8-quantized weights, so the
  // result is only approximate (tolerance 1.3 below).
  FullyConnectedOpModel model(/*input=*/{TensorType_FLOAT32, {2, 10}},
                              /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64},
                              /*output=*/{TensorType_FLOAT32});
  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3)));
}
1370 
TEST(FullyConnectedOpTest, QuantizedOutputMultiplierGreaterThan1) {
  // real_multiplier = 2.
  FullyConnectedOpModel model(
      /*input=*/{TensorType_UINT8, {2, 10}, -127, 128},
      /*weights=*/{TensorType_UINT8, {3, 10}, -127, 128},
      /*output=*/{TensorType_UINT8, {}, -63.5, 64});

  model.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  model.SetBias({1, 2, 3});

  model.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                     24, 25, 26,  // first batch
                                     58, 59, 60,  // second batch
                                 })));
}
1397 
// Wraps a single SOFTMAX operator for NNAPI-delegate testing.
class SoftmaxOpModel : public SingleOpModelWithNNAPI {
 public:
  // 'beta' scales the logits before exponentiation. The output tensor reuses
  // the input's TensorData (same shape and type).
  SoftmaxOpModel(const TensorData& input, float beta) {
    input_ = AddInput(input);
    output_ = AddOutput(input);
    SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
                 CreateSoftmaxOptions(builder_, beta).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  // Populates a sub-range of the input starting at 'offset'.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 private:
  int input_;
  int output_;
};
1422 
TEST(SoftmaxOpTest, SimpleTest) {
  SoftmaxOpModel model({TensorType_FLOAT32, {2, 5}}, /*beta=*/1.0);
  model.SetInput({
      1.0, 2.0, 3.0, 4.0, 5.0,       // b = 0
      -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 1
  });

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // The second row is the first row negated, so its softmax is the first
  // row's reversed.
  EXPECT_THAT(
      model.GetOutput(),
      NnapiArrayFloatNear({0.011656231, 0.031684921, 0.086128544, 0.234121657,
                           0.636408647, 0.636408647, 0.234121657, 0.086128544,
                           0.031684921, 0.011656231}));
}
1438 
// Softmax with a non-default temperature (beta=2) sharpens the distribution
// relative to beta=1.
TEST(SoftmaxOpTest, Beta2) {
  SoftmaxOpModel m({TensorType_FLOAT32, {1, 5}}, /*beta=*/2.0);
  m.SetInput({1.0, 2.0, 3.0, 4.0, 5.0});  // b = 0

  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({0.000290076, 0.002143387, 0.015837606,
                                   0.117024957, 0.864703974}));
}
1451 
1452 TEST(SoftmaxOpTest, 3dInput) {
1453   SoftmaxOpModel m({TensorType_FLOAT32, {2, 2, 5}}, /*beta=*/1.0);
1454   m.SetInput({
1455       1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
1456       -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
1457       5.0,  1.0,  2.0,  3.0,  4.0,   // b = 1
1458       -5.0, -1.0, -2.0, -3.0, -4.0,  // b = 1
1459   });
1460 
1461   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1462 
1463   EXPECT_THAT(
1464       m.GetOutput(),
1465       NnapiArrayFloatNear(
1466           {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
1467            0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,
1468            0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657,
1469            0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921}));
1470 }
1471 
1472 TEST(SoftmaxOpTest, 4dInput) {
1473   SoftmaxOpModel m({TensorType_FLOAT32, {2, 2, 1, 5}}, /*beta=*/1.0);
1474   m.SetInput({
1475       1.0,  2.0,  3.0,  4.0,  5.0,   // b = 0
1476       -1.0, -2.0, -3.0, -4.0, -5.0,  // b = 0
1477       5.0,  1.0,  2.0,  3.0,  4.0,   // b = 1
1478       -5.0, -1.0, -2.0, -3.0, -4.0,  // b = 1
1479   });
1480 
1481   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1482 
1483   EXPECT_THAT(
1484       m.GetOutput(),
1485       NnapiArrayFloatNear(
1486           {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
1487            0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231,
1488            0.636408647, 0.011656231, 0.031684921, 0.086128544, 0.234121657,
1489            0.011656231, 0.636408647, 0.234121657, 0.086128544, 0.031684921}));
1490 }
1491 
1492 class ReshapeOpModel : public SingleOpModelWithNNAPI {
1493  public:
ReshapeOpModel(std::initializer_list<int> input_shape,std::initializer_list<int> new_shape)1494   ReshapeOpModel(std::initializer_list<int> input_shape,
1495                  std::initializer_list<int> new_shape) {
1496     input_ = AddInput(TensorType_FLOAT32);
1497     new_shape_ = AddConstInput<int>(TensorType_INT32, new_shape,
1498                                     {static_cast<int>(new_shape.size())});
1499     output_ = AddOutput(TensorType_FLOAT32);
1500     SetBuiltinOp(
1501         BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
1502         CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
1503             .Union());
1504     BuildInterpreterWithNNAPI(
1505         {input_shape, {static_cast<int>(new_shape.size())}});
1506   }
1507 
SetInput(std::initializer_list<float> data)1508   void SetInput(std::initializer_list<float> data) {
1509     PopulateTensor<float>(input_, data);
1510   }
GetOutput()1511   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1512   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1513 
1514  private:
1515   int input_;
1516   int new_shape_;
1517   int output_;
1518 };
1519 
// Reshape [1,2,4,1] -> [2,2,2]; element order is unchanged.
TEST(NNAPIDelegate, ReshapeSimpleTest) {
  ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2});
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({1, 2, 3, 4, 5, 6, 7, 8}));
}
1527 
1528 class SqueezeOpModel : public SingleOpModelWithNNAPI {
1529  public:
SqueezeOpModel(const TensorData & input,const TensorData & output,std::initializer_list<int> axis)1530   SqueezeOpModel(const TensorData& input, const TensorData& output,
1531                  std::initializer_list<int> axis) {
1532     input_ = AddInput(input);
1533     output_ = AddOutput(output);
1534     SetBuiltinOp(
1535         BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
1536         CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
1537             .Union());
1538     BuildInterpreterWithNNAPI({GetShape(input_)});
1539   }
1540 
SetInput(std::initializer_list<float> data)1541   void SetInput(std::initializer_list<float> data) {
1542     PopulateTensor<float>(input_, data);
1543   }
GetOutput()1544   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1545   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1546 
1547  private:
1548   int input_;
1549   int new_shape_;
1550   int output_;
1551 };
1552 
1553 // TODO(b/215935381): Enable after resolving issues with flakiness.
TEST(NNAPIDelegate,DISABLED_SqueezeSimpleTest)1554 TEST(NNAPIDelegate, DISABLED_SqueezeSimpleTest) {
1555   std::initializer_list<float> data = {
1556       1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
1557       13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
1558   SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
1559                    {});
1560   m.SetInput(data);
1561   ASSERT_EQ(m.Invoke(), kTfLiteOk);
1562   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
1563   EXPECT_THAT(
1564       m.GetOutput(),
1565       NnapiArrayFloatNear({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
1566                            9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
1567                            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
1568 }
1569 
// Squeezing only axis 2 turns [1,24,1] into [1,24]; the leading unit dim
// stays because it was not listed.
TEST(NNAPIDelegate, SqueezeWithAxisTest) {
  std::initializer_list<float> data = {
      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
  SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
                   /*axis=*/{2});
  m.SetInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24}));
  EXPECT_THAT(
      m.GetOutput(),
      NnapiArrayFloatNear({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
                           9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                           17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
1585 
1586 class L2NormOpModel : public SingleOpModelWithNNAPI {
1587  public:
L2NormOpModel(const TensorData & input,const TensorData & output,ActivationFunctionType activation_type)1588   L2NormOpModel(const TensorData& input, const TensorData& output,
1589                 ActivationFunctionType activation_type) {
1590     input_ = AddInput(input);
1591     output_ = AddOutput(output);
1592     SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
1593                  CreateL2NormOptions(builder_, activation_type).Union());
1594     BuildInterpreterWithNNAPI({GetShape(input_)});
1595   }
1596 
SetInput(std::initializer_list<float> data)1597   void SetInput(std::initializer_list<float> data) {
1598     PopulateTensor<float>(input_, data);
1599   }
GetOutput()1600   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1601   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1602 
1603  private:
1604   int input_;
1605   int new_shape_;
1606   int output_;
1607 };
1608 
// L2 norm of the input vector is 2, so every element is halved.
TEST(NNAPIDelegate, L2NormSimpleTest) {
  std::initializer_list<float> data = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1};
  L2NormOpModel m({TensorType_FLOAT32, {1, 1, 1, 6}},
                  {TensorType_FLOAT32, {1, 1, 1, 6}},
                  ActivationFunctionType_NONE);
  m.SetInput(data);
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 6}));
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
1620 
1621 class TransposeSimpleModel : public SingleOpModelWithNNAPI {
1622  public:
TransposeSimpleModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape,std::initializer_list<int> perm)1623   TransposeSimpleModel(std::initializer_list<int> input_shape,
1624                        std::initializer_list<int> perm_shape,
1625                        std::initializer_list<int> perm) {
1626     input_ = AddInput(TensorType_FLOAT32);
1627     perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
1628     output_ = AddOutput(TensorType_FLOAT32);
1629     SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
1630                  CreateTransposeOptions(builder_).Union());
1631     BuildInterpreterWithNNAPI({input_shape, perm_shape});
1632   }
1633 
SetInput(std::initializer_list<float> data)1634   void SetInput(std::initializer_list<float> data) {
1635     PopulateTensor<float>(input_, data);
1636   }
1637 
GetOutput()1638   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1639   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1640 
1641  private:
1642   int input_;
1643   int perm_;
1644   int output_;
1645 };
1646 
// Transposes [2,3,4] with perm {2,0,1}, i.e. the last axis becomes first.
TEST(NNAPIDelegate, TransposeSimpleTest) {
  TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1});
  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({0, 4, 8,  12, 16, 20,   //
                                   1, 5, 9,  13, 17, 21,   //
                                   2, 6, 10, 14, 18, 22,   //
                                   3, 7, 11, 15, 19, 23}));
}
1657 
1658 class ElementwiseOpBaseModel : public SingleOpModelWithNNAPI {
1659  public:
input() const1660   int input() const { return input_; }
output() const1661   int output() const { return output_; }
1662 
1663  protected:
1664   int input_;
1665   int output_;
1666 };
1667 
1668 class ElementwiseOpFloatModel : public ElementwiseOpBaseModel {
1669  public:
ElementwiseOpFloatModel(BuiltinOperator op,std::initializer_list<int> input_shape)1670   ElementwiseOpFloatModel(BuiltinOperator op,
1671                           std::initializer_list<int> input_shape) {
1672     input_ = AddInput(TensorType_FLOAT32);
1673     output_ = AddOutput(TensorType_FLOAT32);
1674     SetBuiltinOp(op, BuiltinOptions_NONE, 0);
1675     BuildInterpreterWithNNAPI({input_shape});
1676   }
1677 };
1678 
// ABS flips the sign of the negative entries and leaves the rest alone.
TEST(Elementwise, Abs) {
  ElementwiseOpFloatModel m(BuiltinOperator_ABS, {1, 2, 4, 1});
  m.PopulateTensor<float>(m.input(),
                          {0.f, -6.2f, 2.f, 4.f, 3.f, -2.f, 10.f, 1.f});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.ExtractVector<float>(m.output()),
              NnapiArrayFloatNear({0.f, 6.2f, 2.f, 4.f, 3.f, 2.f, 10.f, 1.f}));
  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 4, 1}));
}
1692 
// EXP: e^1, e^0 and e^-1 at every position.
TEST(Elementwise, Exp) {
  ElementwiseOpFloatModel m(BuiltinOperator_EXP, {3, 1, 2});
  m.PopulateTensor<float>(m.input(), {1.0, 0.0, -1.0, 1.0, 1.0, -1.0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      m.ExtractVector<float>(m.output()),
      NnapiArrayFloatNear({2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879}));
  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({3, 1, 2}));
}
1702 
// LOG: ln(1) == 0 and ln(pi) ~= 1.14473.
TEST(Elementwise, Log) {
  ElementwiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
  m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.ExtractVector<float>(m.output()),
              NnapiArrayFloatNear({0, 1.14473, 0, 0}));
  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
1711 
// RSQRT: 1/sqrt(x) for a few perfect and non-perfect squares.
TEST(Elementwise, Rsqrt) {
  ElementwiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
  m.PopulateTensor<float>(m.input(), {1, 2, 4, 9});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.ExtractVector<float>(m.output()),
              NnapiArrayFloatNear({1, 0.7071, 0.5, 0.33333}));
  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
1720 
// SIN: sin(0) and sin(+/-pi) are ~0; sin(1) ~= 0.84147.
TEST(Elementwise, Sin) {
  ElementwiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
  m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.ExtractVector<float>(m.output()),
              NnapiArrayFloatNear({0, 0, 0, 0.84147}));
  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
1729 
// SQRT over non-negative inputs, including zero.
TEST(Elementwise, Sqrt) {
  ElementwiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
  m.PopulateTensor<float>(m.input(), {0, 1, 2, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.ExtractVector<float>(m.output()),
              NnapiArrayFloatNear({0, 1, 1.41421, 2}));
  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
1738 
1739 class FloatSubOpModel : public SingleOpModelWithNNAPI {
1740  public:
FloatSubOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)1741   FloatSubOpModel(const TensorData& input1, const TensorData& input2,
1742                   const TensorData& output,
1743                   ActivationFunctionType activation_type) {
1744     input1_ = AddInput(input1);
1745     input2_ = AddInput(input2);
1746     output_ = AddOutput(output);
1747     SetBuiltinOp(BuiltinOperator_SUB, BuiltinOptions_SubOptions,
1748                  CreateMulOptions(builder_, activation_type).Union());
1749     BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
1750   }
1751 
input1()1752   int input1() { return input1_; }
input2()1753   int input2() { return input2_; }
1754 
GetOutput()1755   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1756 
1757  protected:
1758   int input1_;
1759   int input2_;
1760   int output_;
1761 };
1762 
// Element-wise subtraction with no fused activation.
TEST(NNAPIDelegate, SubWithNoActivation) {
  FloatSubOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
                    {TensorType_FLOAT32, {1, 2, 2, 1}},
                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // input1 - input2, element-wise.
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-2.1, 0.0, 0.4, 0.3}));
}
1772 
1773 class FloatDivOpModel : public SingleOpModelWithNNAPI {
1774  public:
FloatDivOpModel(const TensorData & input1,const TensorData & input2,const TensorData & output,ActivationFunctionType activation_type)1775   FloatDivOpModel(const TensorData& input1, const TensorData& input2,
1776                   const TensorData& output,
1777                   ActivationFunctionType activation_type) {
1778     input1_ = AddInput(input1);
1779     input2_ = AddInput(input2);
1780     output_ = AddOutput(output);
1781     SetBuiltinOp(BuiltinOperator_DIV, BuiltinOptions_DivOptions,
1782                  CreateMulOptions(builder_, activation_type).Union());
1783     BuildInterpreterWithNNAPI({GetShape(input1_), GetShape(input2_)});
1784   }
1785 
input1()1786   int input1() { return input1_; }
input2()1787   int input2() { return input2_; }
1788 
GetOutput()1789   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1790 
1791  protected:
1792   int input1_;
1793   int input2_;
1794   int output_;
1795 };
1796 
// Element-wise division with no fused activation.
TEST(NNAPIDelegate, DivWithNoActivation) {
  FloatDivOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
                    {TensorType_FLOAT32, {1, 2, 2, 1}},
                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.8, 0.8});
  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.4, 0.2});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  // input1 / input2, element-wise.
  EXPECT_THAT(m.GetOutput(), NnapiArrayFloatNear({-20, 1, 2, 4}));
}
1806 
1807 class BaseConcatenationOpModel : public SingleOpModelWithNNAPI {
1808  public:
BaseConcatenationOpModel()1809   BaseConcatenationOpModel() {}
BaseConcatenationOpModel(const TensorData & input_template,int axis,int num_inputs)1810   BaseConcatenationOpModel(const TensorData& input_template, int axis,
1811                            int num_inputs) {
1812     std::vector<std::vector<int>> all_input_shapes;
1813     for (int i = 0; i < num_inputs; ++i) {
1814       all_input_shapes.push_back(input_template.shape);
1815       AddInput(input_template);
1816     }
1817     output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min,
1818                          input_template.max});
1819     SetBuiltinOp(
1820         BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
1821         CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
1822             .Union());
1823     BuildInterpreterWithNNAPI(all_input_shapes);
1824   }
1825 
1826  protected:
1827   int output_;
1828 };
1829 
1830 class ConcatenationOpModel : public BaseConcatenationOpModel {
1831  public:
1832   using BaseConcatenationOpModel::BaseConcatenationOpModel;
SetInput(int index,std::initializer_list<float> data)1833   void SetInput(int index, std::initializer_list<float> data) {
1834     PopulateTensor(index, data);
1835   }
GetOutput()1836   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1837 };
1838 
// Degenerate concat of a single input: output equals the input.
TEST(NNAPIDelegate, ConcatenationThreeDimensionalOneInput) {
  ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1,
                          /*num_inputs=*/1);
  m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  ASSERT_EQ(m0.Invoke(), kTfLiteOk);
  EXPECT_THAT(m0.GetOutput(), NnapiArrayFloatNear({1, 3, 4, 7}));
}
1846 
// Concatenates four [2,1,2] tensors along the innermost axis.
TEST(NNAPIDelegate, ConcatenationFourInputs) {
  ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2,
                          /*num_inputs=*/4);
  m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(m0.Invoke(), kTfLiteOk);
  // Each output row interleaves the corresponding rows of the four inputs.
  EXPECT_THAT(m0.GetOutput(),
              NnapiArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              }));
}
1861 
1862 class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
1863  public:
1864   using BaseConcatenationOpModel::BaseConcatenationOpModel;
QuantizedConcatenationOpModel(const std::vector<TensorData> & input_template,int axis,int num_inputs,const TensorData & output_template)1865   QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
1866                                 int axis, int num_inputs,
1867                                 const TensorData& output_template) {
1868     std::vector<std::vector<int>> all_input_shapes;
1869     CHECK_EQ(input_template.size(), num_inputs);
1870     for (int i = 0; i < num_inputs; ++i) {
1871       all_input_shapes.push_back(input_template[i].shape);
1872       AddInput(input_template[i]);
1873     }
1874     output_ = AddOutput({output_template.type, /*shape=*/{},
1875                          output_template.min, output_template.max});
1876     SetBuiltinOp(
1877         BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
1878         CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
1879             .Union());
1880     BuildInterpreterWithNNAPI(all_input_shapes);
1881   }
SetInput(int index,std::initializer_list<float> data)1882   void SetInput(int index, std::initializer_list<float> data) {
1883     QuantizeAndPopulate<uint8_t>(index, data);
1884   }
GetOutput()1885   std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
GetDequantizedOutput()1886   std::vector<float> GetDequantizedOutput() {
1887     return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
1888                                GetScale(output_), GetZeroPoint(output_));
1889   }
1890 };
1891 
// Four uint8 inputs sharing one quantization range; checks both the
// dequantized values and the exact quantized bytes.
TEST(NNAPIDelegate, ConcatenationFourInputsQuantized) {
  QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8},
                                   /*axis=*/2,
                                   /*num_inputs=*/4);

  m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(m0.Invoke(), kTfLiteOk);
  EXPECT_THAT(m0.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
                                  137, 157, 138, 158, 139, 159, 140, 160,  //
                                  167, 197, 168, 198, 169, 199, 170, 200,  //
                              }));
}
1912 
// Same data as above but each input carries its own quantization range;
// the op must requantize everything into the output's range, yielding the
// same bytes as the single-range case.
TEST(NNAPIDelegate, ConcatenationFourInputsQuantizedMixedRange) {
  QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
                                    {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
                                    {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
                                    {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
                                   /*axis=*/2, /*num_inputs=*/4,
                                   {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});

  m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
  m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
  m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
  m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
  ASSERT_EQ(m0.Invoke(), kTfLiteOk);
  EXPECT_THAT(m0.GetDequantizedOutput(),
              ElementsAreArray(ArrayFloatNear({
                  1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f,  //
                  4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f,  //
              })));
  EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
                                  137, 157, 138, 158, 139, 159, 140, 160,  //
                                  167, 197, 168, 198, 169, 199, 170, 200,  //
                              }));
}
1936 
1937 class DequantizeOpModel : public SingleOpModelWithNNAPI {
1938  public:
DequantizeOpModel(TensorType inputType,std::initializer_list<int> shape,float min,float max)1939   DequantizeOpModel(TensorType inputType, std::initializer_list<int> shape,
1940                     float min, float max) {
1941     input_ = AddInput({inputType, shape, min, max});
1942     output_ = AddOutput({TensorType_FLOAT32, shape});
1943     SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions,
1944                  CreateDequantizeOptions(builder_).Union());
1945 
1946     BuildInterpreterWithNNAPI({GetShape(input_)});
1947   }
1948 
1949   template <typename T>
SetInput(std::initializer_list<T> data)1950   void SetInput(std::initializer_list<T> data) {
1951     PopulateTensor(input_, data);
1952   }
1953 
GetOutput()1954   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
1955 
1956  private:
1957   int input_;
1958   int output_;
1959 };
1960 
// [-63.5, 64] -> scale=0.5, zero_point=127 for UINT8.
// NOTE(review): the test name says "FourDimensional" but the input is 2-D
// {2, 5}; the name is kept since gtest filters may reference it.
TEST(NNAPIDelegate, DequantizeFourDimensionalUint8) {
  DequantizeOpModel m(TensorType_UINT8, {2, 5}, -63.5, 64);

  m.SetInput<uint8_t>({0, 1, 2, 3, 4, 251, 252, 253, 254, 255});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64})));
}
1970 
// [-64, 63.5] -> scale=0.5, zero_point=0 for INT8.
// NOTE(review): the test name says "FourDimensional" but the input is 2-D
// {2, 5}; the name is kept since gtest filters may reference it.
TEST(NNAPIDelegate, DequantizeFourDimensionalInt8Symm) {
  DequantizeOpModel m(TensorType_INT8, {2, 5}, -64, 63.5);

  m.SetInput<int8_t>({-128, -127, -126, -125, -124, 123, 124, 125, 126, 127});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {-64, -63.5, -63, -62.5, -62, 61.5, 62, 62.5, 63, 63.5})));
}
1981 
1982 class FloorOpModel : public SingleOpModelWithNNAPI {
1983  public:
FloorOpModel(std::initializer_list<int> input_shape,TensorType input_type)1984   FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
1985     input_ = AddInput(TensorType_FLOAT32);
1986     output_ = AddOutput(TensorType_FLOAT32);
1987     SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
1988     BuildInterpreterWithNNAPI({
1989         input_shape,
1990     });
1991   }
1992 
input()1993   int input() { return input_; }
1994 
GetOutput()1995   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()1996   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
1997 
1998  private:
1999   int input_;
2000   int output_;
2001 };
2002 
// Floor over a rank-1 tensor.
TEST(NNAPIDelegate, FloorSingleDim) {
  FloorOpModel model({2}, TensorType_FLOAT32);
  model.PopulateTensor<float>(model.input(), {8.5, 0.0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({8, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
}
2010 
// Floor on values straddling integer boundaries in both directions;
// negatives round toward -inf (e.g. -0.0001 -> -1).
TEST(NNAPIDelegate, FloorMultiDims) {
  FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32);
  model.PopulateTensor<float>(model.input(),
                              {0.0001, 8.0001, 0.9999, 9.9999, 0.5,  //
                               -0.0001, -8.0001, -0.9999, -9.9999, -0.5});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
}
2030 
2031 class LocalResponseNormOpModel : public SingleOpModelWithNNAPI {
2032  public:
LocalResponseNormOpModel(std::initializer_list<int> input_shape,int radius,float bias,float alpha,float beta)2033   LocalResponseNormOpModel(std::initializer_list<int> input_shape, int radius,
2034                            float bias, float alpha, float beta) {
2035     input_ = AddInput(TensorType_FLOAT32);
2036     output_ = AddOutput(TensorType_FLOAT32);
2037     SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
2038                  BuiltinOptions_LocalResponseNormalizationOptions,
2039                  CreateLocalResponseNormalizationOptions(builder_, radius, bias,
2040                                                          alpha, beta)
2041                      .Union());
2042     BuildInterpreterWithNNAPI({input_shape});
2043   }
2044 
SetInput(std::initializer_list<float> data)2045   void SetInput(std::initializer_list<float> data) {
2046     PopulateTensor(input_, data);
2047   }
2048 
GetOutput()2049   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2050 
2051  private:
2052   int input_;
2053   int output_;
2054 };
2055 
// With bias=0, alpha=1, beta=0.5 and a radius covering the whole depth,
// LRN degenerates to an L2 normalization: every input is divided by 2.
TEST(NNAPIDelegate, LocalResponseNormSameAsL2Norm) {
  LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                             /*alpha=*/1.0, /*beta=*/0.5);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
2065 
// alpha=4 doubles the denominator relative to the L2 case: every input is
// divided by 4 (i.e. halved again).
TEST(NNAPIDelegate, LocalResponseNormWithAlpha) {
  LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0,
                             /*alpha=*/4.0, /*beta=*/0.5);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}));
}
2075 
// Adding bias=9 to alpha*sum-of-squares (= 16) gives 25, so every input is
// divided by 5.
TEST(NNAPIDelegate, LocalResponseNormWithBias) {
  LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0,
                             /*alpha=*/4.0, /*beta=*/0.5);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}));
}
2085 
// A radius smaller than the depth makes each output depend only on its
// 5-element neighborhood, so the scale factors differ per element.
TEST(NNAPIDelegate, LocalResponseNormSmallRadius) {
  LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0,
                             /*alpha=*/4.0, /*beta=*/0.5);
  m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              NnapiArrayFloatNear({-0.264926, 0.125109, 0.140112, 0.267261,
                                   -0.161788, 0.0244266}));
}
2095 
// Wraps a single LSH_PROJECTION op behind the NNAPI delegate test harness.
// The hash seed is a float tensor, the input is int32, and the weight tensor
// is registered only when `weight_shape` is non-empty.
class LSHProjectionOpModel : public SingleOpModelWithNNAPI {
 public:
  LSHProjectionOpModel(LSHProjectionType type,
                       std::initializer_list<int> hash_shape,
                       std::initializer_list<int> input_shape,
                       std::initializer_list<int> weight_shape) {
    // Registration order matters: tensor indices are assigned sequentially,
    // and the op expects hash, input, then (optionally) weight.
    hash_ = AddInput(TensorType_FLOAT32);
    input_ = AddInput(TensorType_INT32);
    if (weight_shape.size() > 0) {
      weight_ = AddInput(TensorType_FLOAT32);
    }
    output_ = AddOutput(TensorType_INT32);

    SetBuiltinOp(BuiltinOperator_LSH_PROJECTION,
                 BuiltinOptions_LSHProjectionOptions,
                 CreateLSHProjectionOptions(builder_, type).Union());
    if (weight_shape.size() > 0) {
      BuildInterpreterWithNNAPI({hash_shape, input_shape, weight_shape});
    } else {
      BuildInterpreterWithNNAPI({hash_shape, input_shape});
    }

    // DENSE output covers the full hash volume; SPARSE only uses the first
    // hash dimension (the loop breaks after one iteration).
    output_size_ = 1;
    for (int i : hash_shape) {
      output_size_ *= i;
      if (type == LSHProjectionType_SPARSE) {
        break;
      }
    }
  }
  // Populates the int32 input tensor.
  void SetInput(std::initializer_list<int> data) {
    PopulateTensor(input_, data);
  }

  // Populates the float hash-seed tensor.
  void SetHash(std::initializer_list<float> data) {
    PopulateTensor(hash_, data);
  }

  // Populates the optional weight tensor; only valid when the model was
  // constructed with a non-empty weight_shape.
  void SetWeight(std::initializer_list<float> f) { PopulateTensor(weight_, f); }

  std::vector<int> GetOutput() { return ExtractVector<int>(output_); }

 private:
  int input_;
  int hash_;
  int weight_;
  int output_;

  // Number of output elements computed in the constructor; kept for
  // bookkeeping (no accessor reads it in the visible tests).
  int output_size_;
};
2146 
TEST(NNAPIDelegate, LSHProjectionDense1DInputs) {
  // DENSE projection with unit weights over a 1-D input.
  LSHProjectionOpModel model(LSHProjectionType_DENSE, {3, 2}, {5}, {5});

  model.SetInput({12345, 54321, 67890, 9876, -12345678});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // The computed hash differs between big- and little-endian hosts.
  EXPECT_THAT(model.GetOutput(), ElementsAre(0, 0, 1, 1, 1, 0));
#else
  EXPECT_THAT(model.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0));
#endif
}
2164 
TEST(NNAPIDelegate, LSHProjectionSparse1DInputs) {
  // SPARSE projection; the weight tensor is omitted entirely.
  LSHProjectionOpModel model(LSHProjectionType_SPARSE, {3, 2}, {5}, {});

  model.SetInput({12345, 54321, 67890, 9876, -12345678});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // The computed hash differs between big- and little-endian hosts.
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
#else
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0));
#endif
}
2181 
TEST(NNAPIDelegate, LSHProjectionSparse3DInputs) {
  // SPARSE projection over a 3-D input with a weight vector.
  LSHProjectionOpModel model(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5});

  model.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
                  9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543});
  model.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321});
  model.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  // The computed hash differs between big- and little-endian hosts.
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 0, 4 + 3, 8 + 2));
#else
  EXPECT_THAT(model.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1));
#endif
}
2200 
// Builds a single activation op (RELU, LOGISTIC, ...) for NNAPI testing.
class BaseActivationsOpModel : public SingleOpModelWithNNAPI {
 public:
  // Most activations don't take any options, so this constructor works for
  // them.
  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input) {
    input_ = AddInput(input);
    if (input.type == TensorType_UINT8) {
      // Quantized activation outputs get a fixed 1/256 scale.
      output_ = AddOutput({input.type, {}, 0, 0, 1. / 256});
    } else {
      output_ = AddOutput({input.type, {}});
    }
    SetBuiltinOp(type, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

  // Variant for tests that need explicit control over the output tensor
  // parameters (shape, quantization, etc.).
  BaseActivationsOpModel(BuiltinOperator type, const TensorData& input,
                         const TensorData& output) {
    input_ = AddInput(input);
    output_ = AddOutput(output);
    SetBuiltinOp(type, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }

 protected:
  int input_;   // tensor index of the op's input
  int output_;  // tensor index of the op's output
};
2228 
2229 class FloatActivationsOpModel : public BaseActivationsOpModel {
2230  public:
2231   using BaseActivationsOpModel::BaseActivationsOpModel;
2232 
SetInput(std::initializer_list<float> data)2233   void SetInput(std::initializer_list<float> data) {
2234     PopulateTensor(input_, data);
2235   }
GetOutput()2236   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
2237 };
2238 
// Maximum absolute error accepted for dequantized uint8 results: two
// quantization steps at the 1/256 scale used for activation outputs above.
const float kQuantizedTolerance = 2 * (1. / 256);
2240 
2241 class QuantizedActivationsOpModel : public BaseActivationsOpModel {
2242  public:
2243   using BaseActivationsOpModel::BaseActivationsOpModel;
2244 
2245   template <typename T>
SetInput(std::initializer_list<float> data)2246   void SetInput(std::initializer_list<float> data) {
2247     QuantizeAndPopulate<T>(input_, data);
2248   }
2249   template <typename T>
2250 
GetOutput()2251   std::vector<T> GetOutput() {
2252     return ExtractVector<T>(output_);
2253   }
2254   template <typename T>
GetDequantizedOutput()2255   std::vector<float> GetDequantizedOutput() {
2256     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
2257                          GetZeroPoint(output_));
2258   }
2259 };
2260 
TEST(NNAPIDelegate, Relu) {
  FloatActivationsOpModel model(BuiltinOperator_RELU,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Negatives clamp to zero; non-negatives pass through unchanged.
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0, 0, 2, 4,   //
                                     3, 0, 10, 1,  //
                                 }));
}
2274 
TEST(NNAPIDelegate, Relu1) {
  FloatActivationsOpModel model(
      BuiltinOperator_RELU_N1_TO_1,
      /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0.0, -0.6, 0.2, -0.4,  //
      0.3, -2.0, 1.1, -0.1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Values are clamped to the interval [-1, 1].
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0.0, -0.6, 0.2, -0.4,  //
                                     0.3, -1.0, 1.0, -0.1,  //
                                 }));
}
2288 
TEST(NNAPIDelegate, Relu6) {
  FloatActivationsOpModel model(BuiltinOperator_RELU6,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Values are clamped to the interval [0, 6].
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0, 0, 2, 4,  //
                                     3, 0, 6, 1,  //
                                 }));
}
2302 
TEST(NNAPIDelegate, LogisticFloat) {
  FloatActivationsOpModel model(BuiltinOperator_LOGISTIC,
                                /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
  model.SetInput({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // Expected values are sigmoid(x) for each input element.
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0.5, 0.002473, 0.880797, 0.982014,       //
                                     0.952574, 0.119203, 0.999955, 0.731059,  //
                                 }));
}
2316 
TEST(NNAPIDelegate, LogisticQuantized) {
  QuantizedActivationsOpModel model(
      BuiltinOperator_LOGISTIC,
      /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
  model.SetInput<uint8_t>({
      0, -6, 2, 4,   //
      3, -2, 10, 1,  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  // The dequantized output should match the float sigmoid within tolerance.
  EXPECT_THAT(model.GetDequantizedOutput<uint8_t>(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      0.5, 0.002473, 0.880797, 0.982014,       //
                      0.952574, 0.119203, 0.999955, 0.731059,  //
                  },
                  kQuantizedTolerance)));
  // The raw quantized values may be off by at most one quantum.
  EXPECT_THAT(model.GetOutput<uint8_t>(),
              testing::Pointwise(QuantizedNear(),
                                 {128, 1, 227, 251, 244, 32, 255, 188}));
}
2337 
// Builds a RESIZE_BILINEAR op. If `size_data` is non-empty the size tensor is
// baked in as a constant; otherwise it is a runtime input set via SetSize().
class ResizeBilinearOpModel : public SingleOpModelWithNNAPI {
 public:
  ResizeBilinearOpModel(const TensorData& input,
                        std::initializer_list<int> size_data) {
    bool const_size = size_data.size() != 0;
    input_ = AddInput(input);
    if (const_size) {
      size_ = AddConstInput(TensorType_INT32, size_data, {2});
    } else {
      size_ = AddInput({TensorType_INT32, {2}});
    }
    // Output element type mirrors the input's.
    output_ = AddOutput(input.type);
    SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                 BuiltinOptions_ResizeBilinearOptions,
                 CreateResizeBilinearOptions(builder_).Union());
    if (const_size) {
      BuildInterpreterWithNNAPI({GetShape(input_)});
    } else {
      BuildInterpreterWithNNAPI({GetShape(input_), GetShape(size_)});
    }
  }

  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }
  // Only valid when the model was built with a runtime (non-const) size.
  void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }

  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 private:
  int input_;
  int size_;
  int output_;
};
2376 
TEST(ResizeBilinear, Horizontal) {
  // Resize a 1x2 row to 1x3 with the size supplied at run time.
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 1, 2, 1}}, {});
  model.SetInput<float>({3, 6});
  model.SetSize({1, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 5, 6}));
}
2384 
TEST(ResizeBilinear, HorizontalConstant) {
  // Same as Horizontal, but with the size baked in as a constant tensor.
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
  model.SetInput<float>({3, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 5, 6}));
}
2391 
TEST(ResizeBilinear, Vertical) {
  // Resize a 2x1 column to 3x1 with the size supplied at run time.
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 1, 1}}, {});
  model.SetInput<float>({3, 9});
  model.SetSize({3, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 7, 9}));
}
2399 
TEST(ResizeBilinear, VerticalConstant) {
  // Same as Vertical, but with the size baked in as a constant tensor.
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
  model.SetInput<float>({3, 9});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({3, 7, 9}));
}
2406 
TEST(ResizeBilinear, TwoDimensional) {
  // Upscale a 2x2 patch to 3x3 with the size supplied at run time.
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, {});
  model.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  model.SetSize({3, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({
                                            3, 5, 6,    //
                                            7, 9, 10,   //
                                            9, 11, 12,  //
                                        }));
}
2421 
TEST(ResizeBilinear, TwoDimensionalConstant) {
  // Same as TwoDimensional, but with the size baked in as a constant tensor.
  ResizeBilinearOpModel model({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
  model.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({
                                            3, 5, 6,    //
                                            7, 9, 10,   //
                                            9, 11, 12,  //
                                        }));
}
2435 
// Generic PAD test model parameterized on the element type T. Subclasses wire
// up the tensors; this base provides population and extraction helpers.
template <typename T>
class PadOpModel : public SingleOpModelWithNNAPI {
 public:
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor<T>(input_, data);
  }

  // Quantizes `data` with the input tensor's parameters and populates it.
  template <typename QuantizedInputOutput>
  void SetQuantizedInput(std::initializer_list<float> data) {
    QuantizeAndPopulate<QuantizedInputOutput>(input_, data);
  }

  // Quantizes the scalar pad value into the constant-values tensor.
  // NOTE(review): only valid for subclasses that register constant_values_;
  // PadOpConstModel below does not.
  template <typename QuantizedInputOutput>
  void SetQuantizedPadValue(float data) {
    QuantizeAndPopulate<QuantizedInputOutput>(constant_values_, {data});
  }

  // Populates the flattened paddings tensor (before/after pairs, as laid out
  // by the paddings shape the subclass registered).
  void SetPaddings(std::initializer_list<int> paddings) {
    PopulateTensor<int>(paddings_, paddings);
  }

  std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Returns the output converted back to float using the output tensor's
  // scale and zero point.
  template <typename QuantizedInputOutput>
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<QuantizedInputOutput>(
        ExtractVector<QuantizedInputOutput>(output_), GetScale(output_),
        GetZeroPoint(output_));
  }

 protected:
  int input_;            // tensor index of the data input
  int output_;           // tensor index of the padded output
  int paddings_;         // tensor index of the paddings input
  int constant_values_;  // tensor index of the optional pad-value input
};
2473 
// PAD with the paddings tensor baked in as a constant at build time.
class PadOpConstModel : public PadOpModel<float> {
 public:
  PadOpConstModel(const TensorData& input,
                  std::initializer_list<int> paddings_shape,
                  std::initializer_list<int> paddings,
                  const TensorData& output) {
    input_ = AddInput(input);
    paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
    output_ = AddOutput(output);

    SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
                 CreatePadOptions(builder_).Union());
    // Only the data input is runtime-fed; the paddings are constant.
    BuildInterpreterWithNNAPI({input.shape});
  }
};
2489 
TEST(NNAPIDelegate, PadAdvancedConstTest) {
  // Pad the height dimension by (0, 2) and the width dimension by (1, 3).
  PadOpConstModel model({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                        {0, 0, 0, 2, 1, 3, 0, 0}, {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
2500 
// Base harness for SPACE_TO_BATCH_ND: helpers to populate the three inputs
// and read back the output. Subclasses register the tensors and the op.
class SpaceToBatchNDOpModel : public SingleOpModelWithNNAPI {
 public:
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  void SetBlockShape(std::initializer_list<int> data) {
    PopulateTensor<int>(block_shape_, data);
  }

  void SetPaddings(std::initializer_list<int> data) {
    PopulateTensor<int>(paddings_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 protected:
  int input_;        // tensor index of the data input
  int block_shape_;  // tensor index of the block-shape input
  int paddings_;     // tensor index of the paddings input
  int output_;       // tensor index of the output
};
2524 
// SPACE_TO_BATCH_ND with block_shape and paddings baked in as constant
// tensors; only the float data input is runtime-fed.
class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel {
 public:
  SpaceToBatchNDOpConstModel(std::initializer_list<int> input_shape,
                             std::initializer_list<int> block_shape,
                             std::initializer_list<int> paddings) {
    input_ = AddInput(TensorType_FLOAT32);
    block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
    paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2});
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
                 BuiltinOptions_SpaceToBatchNDOptions,
                 CreateSpaceToBatchNDOptions(builder_).Union());
    BuildInterpreterWithNNAPI({input_shape});
  }
};
2541 
TEST(NNAPIDelegate, SpaceToBatchNDSimpleConstTest) {
  // A 2x2 block on a 4x4 spatial input with no padding: batch grows 1 -> 4.
  SpaceToBatchNDOpConstModel model({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6,
                                   8, 14, 16}));
}
2550 
TEST(NNAPIDelegate, SpaceToBatchNDMultipleInputBatchesConstTest) {
  // Starting from 2 input batches, the 2x2 block expands batch to 8.
  SpaceToBatchNDOpConstModel model({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6,
                                   8, 14, 16}));
}
2559 
TEST(NNAPIDelegate, SpaceToBatchNDSimplePaddingConstTest) {
  // Padding makes the 5x2 spatial extent divisible by the 3x2 block.
  SpaceToBatchNDOpConstModel model({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,
                                     0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
                                 }));
}
2570 
TEST(NNAPIDelegate, SpaceToBatchNDComplexPaddingConstTest) {
  // Asymmetric padding on both spatial dimensions with a 3x2 block.
  SpaceToBatchNDOpConstModel model({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0,
                                     0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0,
                                     0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
                                 }));
}
2582 
// Builds a STRIDED_SLICE op with constant begin/end/strides tensors.
// `input_type` and `tensor_input_type` must describe the same element type.
template <typename input_type = float,
          TensorType tensor_input_type = TensorType_FLOAT32>
class StridedSliceOpModel : public SingleOpModelWithNNAPI {
 public:
  StridedSliceOpModel(std::initializer_list<int> input_shape,
                      std::initializer_list<int> begin_shape,
                      std::initializer_list<int> begin_data,
                      std::initializer_list<int> end_shape,
                      std::initializer_list<int> end_data,
                      std::initializer_list<int> strides_shape,
                      std::initializer_list<int> strides_data, int begin_mask,
                      int end_mask, int ellipsis_mask, int new_axis_mask,
                      int shrink_axis_mask) {
    // Registration order fixes the op's input order:
    // data, begin, end, strides.
    input_ = AddInput(tensor_input_type);
    begin_ = AddConstInput(TensorType_INT32, begin_data, begin_shape);
    end_ = AddConstInput(TensorType_INT32, end_data, end_shape);
    strides_ = AddConstInput(TensorType_INT32, strides_data, strides_shape);
    output_ = AddOutput(tensor_input_type);
    SetBuiltinOp(
        BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions,
        CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
                                  new_axis_mask, shrink_axis_mask)
            .Union());
    BuildInterpreterWithNNAPI(
        {input_shape, begin_shape, end_shape, strides_shape});
  }

  void SetInput(std::initializer_list<input_type> data) {
    PopulateTensor<input_type>(input_, data);
  }

  std::vector<input_type> GetOutput() {
    return ExtractVector<input_type>(output_);
  }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
  int input_;
  int begin_;
  int end_;
  int strides_;
  int output_;
};
2626 
TEST(StridedSliceOpTest, In1D) {
  // Slice [1:3) of a 4-element vector with stride 1.
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 0, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({2, 3}));
}
2634 
TEST(StridedSliceOpTest, In1D_BeginMask) {
  // begin_mask bit 0 is set, so the begin index is ignored and the slice
  // starts from element 0.
  StridedSliceOpModel<> model({4}, {1}, {1}, {1}, {3}, {1}, {1}, 1, 0, 0, 0,
                              0);
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 2, 3}));
}
2642 
TEST(StridedSliceOpTest, In2D_Stride2) {
  // Stride 2 along both dimensions keeps every other element.
  StridedSliceOpModel<> model({2, 3}, {2}, {0, 0}, {2}, {2, 3}, {2}, {2, 2},
                              0, 0, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 3}));
}
2651 
TEST(StridedSliceOpTest, In2D_EndMask) {
  // end_mask bit 1 is set, so the end index for dim 1 is ignored and the
  // slice runs to the end of the row.
  StridedSliceOpModel<> model({2, 3}, {2}, {1, 0}, {2}, {2, 2}, {2}, {1, 1},
                              0, 2, 0, 0, 0);
  model.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({4, 5, 6}));
}
2660 
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) {
  // shrink_axis_mask=4 collapses the last dimension, keeping index 0 of it.
  StridedSliceOpModel<> model({2, 3, 2}, {3}, {0, 0, 0}, {3}, {2, 3, 1}, {3},
                              {1, 1, 1}, 0, 0, 0, 0, 4);
  model.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({1, 3, 5, 7, 9, 11}));
}
2669 
2670 static float rnn_input[] = {
2671     0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
2672     0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
2673     -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
2674     0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
2675     0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
2676     0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
2677     -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
2678     -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
2679     0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
2680     0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
2681     0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
2682     -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
2683     0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
2684     -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
2685     -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
2686     -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
2687     0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
2688     -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
2689     -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
2690     0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
2691     -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
2692     0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
2693     0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
2694     0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
2695     -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
2696     0.93455386,   -0.6324693,   -0.083922029};
2697 
2698 static float rnn_golden_output[] = {
2699     0.496726,   0,          0.965996,  0,         0.0584254, 0,
2700     0,          0.12315,    0,         0,         0.612266,  0.456601,
2701     0,          0.52286,    1.16099,   0.0291232,
2702 
2703     0,          0,          0.524901,  0,         0,         0,
2704     0,          1.02116,    0,         1.35762,   0,         0.356909,
2705     0.436415,   0.0355727,  0,         0,
2706 
2707     0,          0,          0,         0.262335,  0,         0,
2708     0,          1.33992,    0,         2.9739,    0,         0,
2709     1.31914,    2.66147,    0,         0,
2710 
2711     0.942568,   0,          0,         0,         0.025507,  0,
2712     0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
2713     0.8158,     1.21805,    0.586239,  0.25427,
2714 
2715     1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
2716     0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
2717     0,          1.22031,    1.30117,   0.495867,
2718 
2719     0.222187,   0,          0.72725,   0,         0.767003,  0,
2720     0,          0.147835,   0,         0,         0,         0.608758,
2721     0.469394,   0.00720298, 0.927537,  0,
2722 
2723     0.856974,   0.424257,   0,         0,         0.937329,  0,
2724     0,          0,          0.476425,  0,         0.566017,  0.418462,
2725     0.141911,   0.996214,   1.13063,   0,
2726 
2727     0.967899,   0,          0,         0,         0.0831304, 0,
2728     0,          1.00378,    0,         0,         0,         1.44818,
2729     1.01768,    0.943891,   0.502745,  0,
2730 
2731     0.940135,   0,          0,         0,         0,         0,
2732     0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
2733     1.30225,    1.59644,    0.70222,   0,
2734 
2735     0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
2736     0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
2737     0.0454298,  0.300267,   0.562784,  0.395095,
2738 
2739     0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
2740     0,          0,          0,         0.735363,  0.0759267, 1.91017,
2741     0.941888,   0,          0,         0,
2742 
2743     0,          0,          1.5909,    0,         0,         0,
2744     0,          0.5755,     0,         0.184687,  0,         1.56296,
2745     0.625285,   0,          0,         0,
2746 
2747     0,          0,          0.0857888, 0,         0,         0,
2748     0,          0.488383,   0.252786,  0,         0,         0,
2749     1.02817,    1.85665,    0,         0,
2750 
2751     0.00981836, 0,          1.06371,   0,         0,         0,
2752     0,          0,          0,         0.290445,  0.316406,  0,
2753     0.304161,   1.25079,    0.0707152, 0,
2754 
2755     0.986264,   0.309201,   0,         0,         0,         0,
2756     0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
2757     0.524981,   1.92076,    2.07013,   0.333244,
2758 
2759     0.415153,   0.210318,   0,         0,         0,         0,
2760     0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
2761     0.628881,   3.58099,    1.49974,   0};
2762 
// Input-to-hidden weights for the RNN black-box test: 16 units x 8 inputs =
// 128 values, matching the {units_, input_size_} weights shape used by
// RNNOpModel(2, 16, 8) below.
static std::initializer_list<float> rnn_weights = {
    0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
    0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
    0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
    -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
    -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
    -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
    -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
    0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
    0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
    0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
    -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
    0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
    -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
    -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
    0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
    0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
    0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
    -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
    0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
    0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
    -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
    0.277308,    0.415818};
2786 
// Recurrent (hidden-to-hidden) weights: a 16x16 matrix. Laid out row-major,
// the stride-17 spacing of the 0.1 entries places them on the main diagonal,
// i.e. this is 0.1 * identity.
static std::initializer_list<float> rnn_recurrent_weights = {
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0.1};
2804 
// Per-unit bias for the 16 RNN units.
static std::initializer_list<float> rnn_bias = {
    0.065691948, -0.69055247, 0.1107955,  -0.97084129, -0.23957068, -0.23566568,
    -0.389184,   0.47481549,  -0.4791103, 0.29931796,  0.10463274,  0.83918178,
    0.37197268,  0.61957061,  0.3956964,  -0.37609905};
2809 
// Test model wrapping a single BuiltinOperator_RNN op (RELU activation),
// built through the NNAPI delegate. The order of the Add*Input/AddOutput
// calls in the constructor defines the op's tensor indices and must match
// the RNN op signature: input, weights, recurrent weights, bias, hidden
// state (variable), output.
class RNNOpModel : public SingleOpModelWithNNAPI {
 public:
  // `weights` / `recurrent_weights` let tests choose the weight tensor types;
  // both default to float32.
  RNNOpModel(int batches, int units, int size,
             const TensorType weights = TensorType_FLOAT32,
             const TensorType recurrent_weights = TensorType_FLOAT32)
      : batches_(batches), units_(units), input_size_(size) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(weights);
    recurrent_weights_ = AddInput(recurrent_weights);
    bias_ = AddInput(TensorType_FLOAT32);
    // Hidden state is a stateful (variable) tensor carried across invocations.
    hidden_state_ = AddVariableInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
        CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
    BuildInterpreterWithNNAPI({
        {batches_, input_size_},  // input tensor
        {units_, input_size_},    // weights tensor
        {units_, units_},         // recurrent weights tensor
        {units_},                 // bias tensor
        {batches_, units_}        // hidden state tensor
    });
  }

  // Populates the bias tensor.
  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }

  // Populates the input-to-hidden weights tensor.
  void SetWeights(std::initializer_list<float> f) {
    PopulateTensor(weights_, f);
  }

  // Populates the hidden-to-hidden (recurrent) weights tensor.
  void SetRecurrentWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_weights_, f);
  }

  // Populates the whole input tensor.
  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  // Populates [begin, end) into the input tensor starting at `offset`
  // elements; used to write one batch at a time.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Returns the output tensor contents after Invoke().
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices assigned by the Add*Input/AddOutput calls above.
  int input_;
  int weights_;
  int recurrent_weights_;
  int bias_;
  int hidden_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
};
2870 
// Feeds the canned RNN input sequence through the delegate one timestep at a
// time and checks each step's output against the golden values.
TEST(NNAPIDelegate, RnnBlackBoxTest) {
  RNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(rnn_weights);
  rnn.SetBias(rnn_bias);
  rnn.SetRecurrentWeights(rnn_recurrent_weights);

  const int num_steps = sizeof(rnn_input) / sizeof(float) /
                        (rnn.input_size() * rnn.num_batches());

  for (int step = 0; step < num_steps; ++step) {
    float* input_begin = rnn_input + step * rnn.input_size();
    float* input_end = input_begin + rnn.input_size();
    // Both batches receive identical input data.
    rnn.SetInput(0, input_begin, input_end);
    rnn.SetInput(rnn.input_size(), input_begin, input_end);

    ASSERT_EQ(rnn.Invoke(), kTfLiteOk);

    // The expected output likewise repeats the golden slice once per batch.
    float* golden_begin = rnn_golden_output + step * rnn.num_units();
    float* golden_end = golden_begin + rnn.num_units();
    std::vector<float> expected(golden_begin, golden_end);
    expected.insert(expected.end(), golden_begin, golden_end);

    EXPECT_THAT(rnn.GetOutput(), NnapiArrayFloatNear(expected));
  }
}
2897 
// SVDF test input: 10 invocations, each supplying 2 batches x 3 features
// (60 values total); the blank-line groups below are one invocation each.
static float svdf_input[] = {
    0.12609188,  -0.46347019, -0.89598465,
    0.35867718,  0.36897406,  0.73463392,

    0.14278367,  -1.64410412, -0.75222826,
    -0.57290924, 0.12729003,  0.7567004,

    0.49837467,  0.19278903,  0.26584083,
    0.17660543,  0.52949083,  -0.77931279,

    -0.11186574, 0.13164264,  -0.05349274,
    -0.72674477, -0.5683046,  0.55900657,

    -0.68892461, 0.37783599,  0.18263303,
    -0.63690937, 0.44483393,  -0.71817774,

    -0.81299269, -0.86831826, 1.43940818,
    -0.95760226, 1.82078898,  0.71135032,

    -1.45006323, -0.82251364, -1.69082689,
    -1.65087092, -1.89238167, 1.54172635,

    0.03966608,  -0.24936394, -0.77526885,
    2.06740379,  -1.51439476, 1.43768692,

    0.11771342,  -0.23761693, -0.65898693,
    0.31088525,  -1.55601168, -0.87661445,

    -0.89477462, 1.67204106,  -0.53235275,
    -0.6230064,  0.29819036,  1.06939757,
};
2929 
// Expected rank-1 SVDF outputs: 10 invocations of 2 batches x 4 units,
// aligned with the svdf_input groups above.
static float svdf_golden_output_rank_1[] = {
    0.014899,    -0.0517661,  -0.143725,   -0.00271883,
    -0.03004015, 0.09565311,  0.1587342,   0.00784263,

    0.068281,    -0.162217,   -0.152268,   0.00323521,
    0.01582633,  0.03858774,  -0.03001583, -0.02671271,

    -0.0317821,  -0.0333089,  0.0609602,   0.0333759,
    -0.01432795, 0.05524484,  0.1101355,   -0.02382665,

    -0.00623099, -0.077701,   -0.391193,   -0.0136691,
    -0.02333033, 0.02293761,  0.12338032,  0.04326871,

    0.201551,    -0.164607,   -0.179462,   -0.0592739,
    0.01064911,  -0.17503069, 0.07821996,  -0.00224009,

    0.0886511,   -0.0875401,  -0.269283,   0.0281379,
    -0.02282338, 0.09741908,  0.32973239,  0.12281385,

    -0.201174,   -0.586145,   -0.628624,   -0.0330412,
    0.24780814,  -0.39304617, -0.22473189, 0.02589256,

    -0.0839096,  -0.299329,   0.108746,    0.109808,
    0.10084175,  -0.06416984, 0.28936723,  0.0026358,

    0.419114,    -0.237824,   -0.422627,   0.175115,
    -0.2314795,  -0.18584411, -0.4228974,  -0.12928449,

    0.36726,     -0.522303,   -0.456502,   -0.175475,
    0.17012937,  -0.34447709, 0.38505614,  -0.28158101,
};
2961 
// Expected rank-2 SVDF outputs: 10 invocations of 2 batches x 4 units,
// aligned with the svdf_input groups above.
static float svdf_golden_output_rank_2[] = {
    -0.09623547, -0.10193135, 0.11083051,  -0.0347917,
    0.1141196,   0.12965347,  -0.12652366, 0.01007236,

    -0.16396809, -0.21247184, 0.11259045,  -0.04156673,
    0.10132131,  -0.06143532, -0.00924693, 0.10084561,

    0.01257364,  0.0506071,   -0.19287863, -0.07162561,
    -0.02033747, 0.22673416,  0.15487903,  0.02525555,

    -0.1411963,  -0.37054959, 0.01774767,  0.05867489,
    0.09607603,  -0.0141301,  -0.08995658, 0.12867066,

    -0.27142537, -0.16955489, 0.18521598,  -0.12528358,
    0.00331409,  0.11167502,  0.02218599,  -0.07309391,

    0.09593632,  -0.28361851, -0.0773851,  0.17199151,
    -0.00075242, 0.33691186,  -0.1536046,  0.16572715,

    -0.27916506, -0.27626723, 0.42615682,  0.3225764,
    -0.37472126, -0.55655634, -0.05013514, 0.289112,

    -0.24418658, 0.07540751,  -0.1940318,  -0.08911639,
    0.00732617,  0.46737891,  0.26449674,  0.24888524,

    -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    0.07736043,  -0.28260678, 0.15666828,  1.14949894,

    -0.57454878, -0.64704704, 0.73235172,  -0.34616736,
    0.21120001,  -0.22927976, 0.02455296,  -0.35906726,
};
2993 
// Test model wrapping a single BuiltinOperator_SVDF op (no activation),
// built through the NNAPI delegate. The order of the Add*Input/AddOutput
// calls in the constructor defines the op's tensor indices and must match
// the SVDF op signature: input, weights_feature, weights_time, bias,
// activation_state (variable), output.
class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
                  int rank,
                  TensorType weights_feature_type = TensorType_FLOAT32,
                  TensorType weights_time_type = TensorType_FLOAT32)
      : batches_(batches),
        units_(units),
        input_size_(input_size),
        memory_size_(memory_size),
        rank_(rank) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_feature_ = AddInput(weights_feature_type);
    weights_time_ = AddInput(weights_time_type);
    // TODO(b/121383394) : figure out why optional bias causes TFLite segfault
    // when using NNAPI delegate.
    bias_ = AddInput(TensorType_FLOAT32);
    const int num_filters = units * rank;
    // Activation state is a stateful (variable) tensor carried across
    // invocations.
    activation_state_ = AddVariableInput(
        TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}});
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(
        BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
        CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
    BuildInterpreterWithNNAPI({
        {batches_, input_size_},              // input tensor
        {units_ * rank, input_size_},         // weights_feature tensor
        {units_ * rank, memory_size_},        // weights_time tensor
        {units_},                             // bias tensor
        {batches, memory_size * num_filters}  // activation_state tensor
    });
    // TODO(b/121383394) : remove once the optional bias bug is fixed.
    PopulateTensor(bias_, std::vector<float>(units_));
  }

  // Populates the weights_feature tensor.
  void SetWeightsFeature(std::initializer_list<float> f) {
    PopulateTensor(weights_feature_, f);
  }

  // Populates the weights_time tensor.
  void SetWeightsTime(std::initializer_list<float> f) {
    PopulateTensor(weights_time_, f);
  }

  // Populates the input tensor with [begin, end) starting at element `offset`.
  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  // Extracts the output tensor from the SVDF op.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }

 protected:
  // Tensor indices assigned by the Add*Input/AddOutput calls above.
  int input_;
  int weights_feature_;
  int weights_time_;
  int bias_;
  int activation_state_;
  int output_;

  int batches_;
  int units_;
  int input_size_;
  int memory_size_;
  int rank_;
};
3065 
// SVDF model using the default (float32) weight tensor types; exists so
// tests can name the plain-float variant distinctly from hybrid variants.
class SVDFOpModel : public BaseSVDFOpModel {
 public:
  using BaseSVDFOpModel::BaseSVDFOpModel;
};
3070 
3071 class SVDFOpTest : public ::testing::Test {
3072  protected:
VerifyGoldens(float golden_input[],float golden_output[],int golden_size,BaseSVDFOpModel * svdf,float tolerance=1e-5)3073   void VerifyGoldens(float golden_input[], float golden_output[],
3074                      int golden_size, BaseSVDFOpModel* svdf,
3075                      float tolerance = 1e-5) {
3076     const int svdf_num_batches = svdf->num_batches();
3077     const int svdf_input_size = svdf->input_size();
3078     const int svdf_num_units = svdf->num_units();
3079     const int input_sequence_size =
3080         golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
3081     // Going over each input batch, setting the input tensor, invoking the SVDF
3082     // op and checking the output with the expected golden values.
3083     for (int i = 0; i < input_sequence_size; i++) {
3084       float* batch_start =
3085           golden_input + i * svdf_input_size * svdf_num_batches;
3086       float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
3087       svdf->SetInput(0, batch_start, batch_end);
3088 
3089       ASSERT_EQ(svdf->Invoke(), kTfLiteOk);
3090 
3091       const float* golden_start =
3092           golden_output + i * svdf_num_units * svdf_num_batches;
3093       const float* golden_end =
3094           golden_start + svdf_num_units * svdf_num_batches;
3095       std::vector<float> expected;
3096       expected.insert(expected.end(), golden_start, golden_end);
3097 
3098       EXPECT_THAT(svdf->GetOutput(),
3099                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
3100     }
3101   }
3102 };
3103 
// SVDF black-box test with rank 1: weights_feature is 4 filters x 3 inputs
// (12 values) and weights_time is 4 filters x memory_size 10 (40 values).
TEST_F(SVDFOpTest, BlackBoxTestRank1) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/1);
  svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
                          0.22197971, 0.12416199, 0.27901134, 0.27557442,
                          0.3905206, -0.36137494, -0.06634006, -0.10640851});

  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});

  // sizeof(svdf_input) is the byte count VerifyGoldens divides by
  // sizeof(float) to recover the timestep count.
  VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
                &svdf);
}
3127 
// SVDF black-box test with rank 2: weights_feature is 8 filters x 3 inputs
// (24 values) and weights_time is 8 filters x memory_size 10 (80 values).
TEST_F(SVDFOpTest, BlackBoxTestRank2) {
  SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
                   /*memory_size=*/10, /*rank=*/2);
  svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
                          0.12416199,  0.15785322,  0.27901134,  0.3905206,
                          0.21931258,  -0.36137494, -0.10640851, 0.31053296,
                          -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
                          0.15294972,  0.38031587,  0.27557442,  0.39635518,
                          -0.21580373, -0.06634006, -0.02702999, 0.27072677});

  svdf.SetWeightsTime(
      {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
       0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,

       0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
       -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,

       -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
       0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,

       -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
       -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,

       -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
       0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,

       -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
       0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,

       -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
       -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,

       0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
       0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});

  // sizeof(svdf_input) is the byte count VerifyGoldens divides by
  // sizeof(float) to recover the timestep count.
  VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
                &svdf);
}
3166 
// Test model wrapping a single BuiltinOperator_LSTM op built through the
// NNAPI delegate. The constructor registers tensors in the exact index order
// the LSTM op expects; optional tensors (CIFG-coupled input gate, peephole
// weights, projection, layer norm) are registered as null inputs when
// disabled so the indices of later tensors stay correct.
class LSTMOpModel : public SingleOpModelWithNNAPI {
 public:
  // `input_shapes` holds one shape per input tensor, in tensor-index order;
  // more than 20 entries signals the layer-norm LSTM variant. `weight_type`
  // selects the element type of all weight tensors (e.g. float vs quantized).
  LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
              bool use_peephole, bool use_projection_weights,
              bool use_projection_bias, float cell_clip, float proj_clip,
              const std::vector<std::vector<int>>& input_shapes,
              const TensorType weight_type)
      : n_batch_(n_batch),
        n_input_(n_input),
        n_cell_(n_cell),
        n_output_(n_output),
        weight_type_(weight_type) {
    input_ = AddInput(TensorType_FLOAT32);

    // Input-to-gate weights; CIFG couples the input gate to the forget gate,
    // so its dedicated weights are omitted.
    if (use_cifg) {
      input_to_input_weights_ = AddNullInput();
    } else {
      input_to_input_weights_ = AddInput(weight_type);
    }

    input_to_forget_weights_ = AddInput(weight_type);
    input_to_cell_weights_ = AddInput(weight_type);
    input_to_output_weights_ = AddInput(weight_type);

    // Recurrent (output-to-gate) weights.
    if (use_cifg) {
      recurrent_to_input_weights_ = AddNullInput();
    } else {
      recurrent_to_input_weights_ = AddInput(weight_type);
    }

    recurrent_to_forget_weights_ = AddInput(weight_type);
    recurrent_to_cell_weights_ = AddInput(weight_type);
    recurrent_to_output_weights_ = AddInput(weight_type);

    // Peephole (cell-to-gate) weights, all optional.
    if (use_peephole) {
      if (use_cifg) {
        cell_to_input_weights_ = AddNullInput();
      } else {
        cell_to_input_weights_ = AddInput(weight_type);
      }
      cell_to_forget_weights_ = AddInput(weight_type);
      cell_to_output_weights_ = AddInput(weight_type);
    } else {
      cell_to_input_weights_ = AddNullInput();
      cell_to_forget_weights_ = AddNullInput();
      cell_to_output_weights_ = AddNullInput();
    }

    // Gate biases (always float32).
    if (use_cifg) {
      input_gate_bias_ = AddNullInput();
    } else {
      input_gate_bias_ = AddInput(TensorType_FLOAT32);
    }
    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
    cell_bias_ = AddInput(TensorType_FLOAT32);
    output_gate_bias_ = AddInput(TensorType_FLOAT32);

    // Optional projection layer.
    if (use_projection_weights) {
      projection_weights_ = AddInput(weight_type);
      if (use_projection_bias) {
        projection_bias_ = AddInput(TensorType_FLOAT32);
      } else {
        projection_bias_ = AddNullInput();
      }
    } else {
      projection_weights_ = AddNullInput();
      projection_bias_ = AddNullInput();
    }

    // Adding the 2 input state tensors.
    input_activation_state_ = AddVariableInput(TensorType_FLOAT32);
    input_cell_state_ = AddVariableInput(TensorType_FLOAT32);

    const bool use_layer_norm = input_shapes.size() > 20;
    // Layer norm weights.
    if (use_layer_norm) {
      const int kInputLayerNormCoeffsIndex = 20;
      const int kForgetLayerNormCoeffsIndex = 21;
      const int kCellLayerNormCoeffsIndex = 22;
      const int kOutputLayerNormCoeffsIndex = 23;

      if (use_cifg) {
        input_layer_norm_coefficients_ = AddNullInput();
      } else {
        input_layer_norm_coefficients_ =
            AddLayerNormCoeffsTensor(kInputLayerNormCoeffsIndex, input_shapes);
      }
      forget_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kForgetLayerNormCoeffsIndex, input_shapes);
      cell_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kCellLayerNormCoeffsIndex, input_shapes);
      output_layer_norm_coefficients_ =
          AddLayerNormCoeffsTensor(kOutputLayerNormCoeffsIndex, input_shapes);
    }

    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
                 CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
                                   cell_clip, proj_clip)
                     .Union());
    BuildInterpreterWithNNAPI(input_shapes);
  }

  // The Set* methods below populate the corresponding tensor; weight tensors
  // go through SetData so they honor weight_type_, biases and layer-norm
  // coefficients are always float32.

  void SetInputToInputWeights(const std::vector<float>& f) {
    SetData(input_to_input_weights_, weight_type_, f);
  }

  void SetInputToForgetWeights(const std::vector<float>& f) {
    SetData(input_to_forget_weights_, weight_type_, f);
  }

  void SetInputToCellWeights(const std::vector<float>& f) {
    SetData(input_to_cell_weights_, weight_type_, f);
  }

  void SetInputToOutputWeights(const std::vector<float>& f) {
    SetData(input_to_output_weights_, weight_type_, f);
  }

  void SetRecurrentToInputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_input_weights_, weight_type_, f);
  }

  void SetRecurrentToForgetWeights(const std::vector<float>& f) {
    SetData(recurrent_to_forget_weights_, weight_type_, f);
  }

  void SetRecurrentToCellWeights(const std::vector<float>& f) {
    SetData(recurrent_to_cell_weights_, weight_type_, f);
  }

  void SetRecurrentToOutputWeights(const std::vector<float>& f) {
    SetData(recurrent_to_output_weights_, weight_type_, f);
  }

  void SetCellToInputWeights(const std::vector<float>& f) {
    SetData(cell_to_input_weights_, weight_type_, f);
  }

  void SetCellToForgetWeights(const std::vector<float>& f) {
    SetData(cell_to_forget_weights_, weight_type_, f);
  }

  void SetCellToOutputWeights(const std::vector<float>& f) {
    SetData(cell_to_output_weights_, weight_type_, f);
  }

  void SetInputGateBias(const std::vector<float>& f) {
    PopulateTensor(input_gate_bias_, f);
  }

  void SetForgetGateBias(const std::vector<float>& f) {
    PopulateTensor(forget_gate_bias_, f);
  }

  void SetCellBias(const std::vector<float>& f) {
    PopulateTensor(cell_bias_, f);
  }

  void SetOutputGateBias(const std::vector<float>& f) {
    PopulateTensor(output_gate_bias_, f);
  }

  void SetProjectionWeights(const std::vector<float>& f) {
    SetData(projection_weights_, weight_type_, f);
  }

  void SetProjectionBias(const std::vector<float>& f) {
    PopulateTensor(projection_bias_, f);
  }

  void SetInputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(input_layer_norm_coefficients_, f);
  }

  void SetForgetLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(forget_layer_norm_coefficients_, f);
  }

  void SetCellLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(cell_layer_norm_coefficients_, f);
  }

  void SetOutputLayerNormCoefficients(const std::vector<float>& f) {
    PopulateTensor(output_layer_norm_coefficients_, f);
  }

  // Populates [begin, end) into the input tensor starting at element
  // `offset`; used to write one batch's timestep at a time.
  void SetInput(int offset, const float* begin, const float* end) {
    // PopulateTensor takes non-const pointers; the data is only read.
    PopulateTensor(input_, offset, const_cast<float*>(begin),
                   const_cast<float*>(end));
  }

  // Returns the output tensor contents after Invoke().
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int num_inputs() { return n_input_; }
  int num_outputs() { return n_output_; }
  int num_cells() { return n_cell_; }
  int num_batches() { return n_batch_; }

 protected:
  // Tensor indices assigned by the Add*Input/AddNullInput/AddOutput calls in
  // the constructor, in LSTM op signature order.
  int input_;
  int input_to_input_weights_;
  int input_to_forget_weights_;
  int input_to_cell_weights_;
  int input_to_output_weights_;

  int recurrent_to_input_weights_;
  int recurrent_to_forget_weights_;
  int recurrent_to_cell_weights_;
  int recurrent_to_output_weights_;

  int cell_to_input_weights_;
  int cell_to_forget_weights_;
  int cell_to_output_weights_;

  int input_gate_bias_;
  int forget_gate_bias_;
  int cell_bias_;
  int output_gate_bias_;

  int projection_weights_;
  int projection_bias_;
  int input_activation_state_;
  int input_cell_state_;

  int input_layer_norm_coefficients_;
  int forget_layer_norm_coefficients_;
  int cell_layer_norm_coefficients_;
  int output_layer_norm_coefficients_;

  int output_;
  // NOTE(review): output_state_ and cell_state_ are never assigned in this
  // class as visible here — possibly vestigial.
  int output_state_;
  int cell_state_;

  int n_batch_;
  int n_input_;
  int n_cell_;
  int n_output_;

 private:
  const TensorType weight_type_;

  // Registers a layer-norm coefficient tensor as a real input when its shape
  // is non-empty (first dim != 0), otherwise as a null (absent) input.
  int AddLayerNormCoeffsTensor(
      int tensor_index, const std::vector<std::vector<int>>& input_shapes) {
    if (input_shapes[tensor_index][0] != 0) {
      return AddInput(TensorType_FLOAT32);
    } else {
      return AddNullInput();
    }
  }
};
3419 
3420 class BaseLstmTest : public ::testing::Test {
3421  protected:
3422   // Weights of the LSTM model. Some are optional.
3423   std::vector<float> input_to_input_weights_;
3424   std::vector<float> input_to_cell_weights_;
3425   std::vector<float> input_to_forget_weights_;
3426   std::vector<float> input_to_output_weights_;
3427   std::vector<float> input_gate_bias_;
3428   std::vector<float> cell_gate_bias_;
3429   std::vector<float> forget_gate_bias_;
3430   std::vector<float> output_gate_bias_;
3431   std::vector<float> recurrent_to_input_weights_;
3432   std::vector<float> recurrent_to_cell_weights_;
3433   std::vector<float> recurrent_to_forget_weights_;
3434   std::vector<float> recurrent_to_output_weights_;
3435   std::vector<float> cell_to_input_weights_;
3436   std::vector<float> cell_to_forget_weights_;
3437   std::vector<float> cell_to_output_weights_;
3438   std::vector<float> projection_weights_;
3439   std::vector<float> input_layer_norm_coefficients_;
3440   std::vector<float> forget_layer_norm_coefficients_;
3441   std::vector<float> cell_layer_norm_coefficients_;
3442   std::vector<float> output_layer_norm_coefficients_;
3443 
3444   // LSTM input is stored as num_batch x num_inputs vector.
3445   std::vector<std::vector<float>> lstm_input_;
3446   // LSTM output is stored as num_batch x num_outputs vector.
3447   std::vector<std::vector<float>> lstm_golden_output_;
3448 
3449   // Compares output up to tolerance to the result of the lstm given the input.
VerifyGoldens(const std::vector<std::vector<float>> & input,const std::vector<std::vector<float>> & output,LSTMOpModel * lstm,float tolerance=1e-5)3450   void VerifyGoldens(const std::vector<std::vector<float>>& input,
3451                      const std::vector<std::vector<float>>& output,
3452                      LSTMOpModel* lstm, float tolerance = 1e-5) {
3453     const int num_batches = input.size();
3454     EXPECT_GT(num_batches, 0);
3455     const int num_inputs = lstm->num_inputs();
3456     EXPECT_GT(num_inputs, 0);
3457     const int input_sequence_size = input[0].size() / num_inputs;
3458     EXPECT_GT(input_sequence_size, 0);
3459     for (int i = 0; i < input_sequence_size; ++i) {
3460       for (int b = 0; b < num_batches; ++b) {
3461         const float* batch_start = input[b].data() + i * num_inputs;
3462         const float* batch_end = batch_start + num_inputs;
3463 
3464         lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
3465       }
3466 
3467       ASSERT_EQ(lstm->Invoke(), kTfLiteOk);
3468 
3469       const int num_outputs = lstm->num_outputs();
3470       std::vector<float> expected;
3471       for (int b = 0; b < num_batches; ++b) {
3472         const float* golden_start_batch = output[b].data() + i * num_outputs;
3473         const float* golden_end_batch = golden_start_batch + num_outputs;
3474         expected.insert(expected.end(), golden_start_batch, golden_end_batch);
3475       }
3476       EXPECT_THAT(lstm->GetOutput(),
3477                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
3478     }
3479   }
3480 };
3481 
// Fixture for the basic LSTM configuration: no CIFG coupling, no peephole
// connections, no projection layer, and no cell/projection clipping.
// SetUp fills in weights for a 4-cell LSTM with 2 inputs plus a one-batch,
// three-timestep input sequence and its golden outputs.
class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  void SetUp() override {
    // Input-to-gate weights (4 cells x 2 inputs each).
    input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
                               -0.34550029, 0.04266912,  -0.15680569,
                               -0.34856534, 0.43890524};
    input_to_cell_weights_ = {-0.50013041, 0.1370284,  0.11810488, 0.2013163,
                              -0.20583314, 0.44344562, 0.22077113, -0.29909778};
    input_to_forget_weights_ = {0.09701663,  0.20334584,  -0.50592935,
                                -0.31343272, -0.40032279, 0.44781327,
                                0.01387155,  -0.35593212};
    input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
                                0.40525138,  0.44272184,  0.03897077,
                                -0.1556896,  0.19487578};
    // Gate biases (one per cell); forget gate biased to 1.
    input_gate_bias_ = {0., 0., 0., 0.};
    cell_gate_bias_ = {0., 0., 0., 0.};
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent (output-to-gate) weights (4 cells x 4 outputs each).
    recurrent_to_input_weights_ = {
        -0.0063535,  -0.2042388,  0.31454784,  -0.35746509,
        0.28902304,  0.08183324,  -0.16555229, 0.02286911,
        -0.13566875, 0.03034258,  0.48091322,  -0.12528998,
        0.24077177,  -0.51332325, -0.33502164, 0.10629296};

    recurrent_to_cell_weights_ = {
        -0.3407414,  0.24443203,  -0.2078532,  0.26320225,
        0.05695659,  -0.00123841, -0.4744786,  -0.35869038,
        -0.06418842, -0.13502428, -0.501764,   0.22830659,
        -0.46367589, 0.26016325,  -0.03894562, -0.16368064};

    recurrent_to_forget_weights_ = {
        -0.48684245, -0.06655136, 0.42224967,  0.2112639,
        0.27654213,  0.20864892,  -0.07646349, 0.45877004,
        0.00141793,  -0.14609534, 0.36447752,  0.09196436,
        0.28053468,  0.01560611,  -0.20127171, -0.01140004};

    recurrent_to_output_weights_ = {
        0.43385774,  -0.17194885, 0.2718237,  0.09215671,
        0.24107647,  -0.39835793, 0.18212086, 0.01301402,
        0.48572797,  -0.50656658, 0.20047462, -0.20607421,
        -0.51818722, -0.15390486, 0.0468148,  0.39922136};

    // One batch, three timesteps of two inputs each, and the corresponding
    // golden outputs (three timesteps of four outputs).
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
                            -0.03716109, 0.12507336, 0.41193449, -0.20860538,
                            -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
  }
};
3530 
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest,LstmBlackBoxTest)3531 TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
3532   const int n_batch = 1;
3533   const int n_input = 2;
3534   // n_cell and n_output have the same size when there is no projection.
3535   const int n_cell = 4;
3536   const int n_output = 4;
3537 
3538   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
3539                    /*use_cifg=*/false, /*use_peephole=*/false,
3540                    /*use_projection_weights=*/false,
3541                    /*use_projection_bias=*/false,
3542                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
3543                    {
3544                        {n_batch, n_input},  // input tensor
3545 
3546                        {n_cell, n_input},  // input_to_input_weight tensor
3547                        {n_cell, n_input},  // input_to_forget_weight tensor
3548                        {n_cell, n_input},  // input_to_cell_weight tensor
3549                        {n_cell, n_input},  // input_to_output_weight tensor
3550 
3551                        {n_cell, n_output},  // recurrent_to_input_weight_tensor
3552                        {n_cell, n_output},  // recurrent_to_forget_weight_tensor
3553                        {n_cell, n_output},  // recurrent_to_cell_weight_tensor
3554                        {n_cell, n_output},  // recurrent_to_output_weight_tensor
3555 
3556                        {0},  // cell_to_input_weight tensor
3557                        {0},  // cell_to_forget_weight tensor
3558                        {0},  // cell_to_output_weight tensor
3559 
3560                        {n_cell},  // input_gate_bias tensor
3561                        {n_cell},  // forget_gate_bias tensor
3562                        {n_cell},  // cell_bias tensor
3563                        {n_cell},  // output_gate_bias tensor
3564 
3565                        {0, 0},  // projection_weight tensor
3566                        {0},     // projection_bias tensor
3567 
3568                        {n_batch, n_output},  // activation_state tensor
3569                        {n_batch, n_cell},    // cell_state tensor
3570                    },
3571                    /*weight_type=*/TensorType_FLOAT32);
3572 
3573   lstm.SetInputToInputWeights(input_to_input_weights_);
3574   lstm.SetInputToCellWeights(input_to_cell_weights_);
3575   lstm.SetInputToForgetWeights(input_to_forget_weights_);
3576   lstm.SetInputToOutputWeights(input_to_output_weights_);
3577 
3578   lstm.SetInputGateBias(input_gate_bias_);
3579   lstm.SetCellBias(cell_gate_bias_);
3580   lstm.SetForgetGateBias(forget_gate_bias_);
3581   lstm.SetOutputGateBias(output_gate_bias_);
3582 
3583   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
3584   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
3585   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
3586   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
3587 
3588   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
3589 }
3590 
// Variant fixture: reuses the same weights and golden outputs; its companion
// test builds the model with explicitly zero-sized layer-normalization
// coefficient tensors instead of leaving them out of the tensor list.
class NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest
    : public NoCifgNoPeepholeNoProjectionNoClippingLstmTest {};
3593 
TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,LstmBlackBoxTest)3594 TEST_F(NoCifgNoPeepholeNoProjectionNoClippingOmittedLayerNormLstmTest,
3595        LstmBlackBoxTest) {
3596   const int n_batch = 1;
3597   const int n_input = 2;
3598   // n_cell and n_output have the same size when there is no projection.
3599   const int n_cell = 4;
3600   const int n_output = 4;
3601 
3602   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
3603                    /*use_cifg=*/false, /*use_peephole=*/false,
3604                    /*use_projection_weights=*/false,
3605                    /*use_projection_bias=*/false,
3606                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
3607                    {
3608                        {n_batch, n_input},  // input tensor
3609 
3610                        {n_cell, n_input},  // input_to_input_weight tensor
3611                        {n_cell, n_input},  // input_to_forget_weight tensor
3612                        {n_cell, n_input},  // input_to_cell_weight tensor
3613                        {n_cell, n_input},  // input_to_output_weight tensor
3614 
3615                        {n_cell, n_output},  // recurrent_to_input_weight_tensor
3616                        {n_cell, n_output},  // recurrent_to_forget_weight_tensor
3617                        {n_cell, n_output},  // recurrent_to_cell_weight_tensor
3618                        {n_cell, n_output},  // recurrent_to_output_weight_tensor
3619 
3620                        {0},  // cell_to_input_weight tensor
3621                        {0},  // cell_to_forget_weight tensor
3622                        {0},  // cell_to_output_weight tensor
3623 
3624                        {n_cell},  // input_gate_bias tensor
3625                        {n_cell},  // forget_gate_bias tensor
3626                        {n_cell},  // cell_bias tensor
3627                        {n_cell},  // output_gate_bias tensor
3628 
3629                        {0, 0},  // projection_weight tensor
3630                        {0},     // projection_bias tensor
3631 
3632                        {n_batch, n_output},  // activation_state tensor
3633                        {n_batch, n_cell},    // cell_state tensor
3634 
3635                        {0},  // input_layer_norm_coefficient tensor
3636                        {0},  // forget_layer_norm_coefficient tensor
3637                        {0},  // cell_layer_norm_coefficient tensor
3638                        {0},  // output_layer_norm_coefficient tensor
3639                    },
3640                    /*weight_type=*/TensorType_FLOAT32);
3641 
3642   lstm.SetInputToInputWeights(input_to_input_weights_);
3643   lstm.SetInputToCellWeights(input_to_cell_weights_);
3644   lstm.SetInputToForgetWeights(input_to_forget_weights_);
3645   lstm.SetInputToOutputWeights(input_to_output_weights_);
3646 
3647   lstm.SetInputGateBias(input_gate_bias_);
3648   lstm.SetCellBias(cell_gate_bias_);
3649   lstm.SetForgetGateBias(forget_gate_bias_);
3650   lstm.SetOutputGateBias(output_gate_bias_);
3651 
3652   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
3653   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
3654   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
3655   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
3656 
3657   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
3658 }
3659 
// Fixture for a CIFG LSTM: the input gate is coupled with the forget gate,
// so no input-gate weights or bias are provided.  Peephole (cell-to-gate)
// weights for the forget and output gates ARE provided.
// NOTE(review): "NoPeephole" in the class name does not match the data set
// here (and the companion test builds with use_peephole=true) — confirm the
// naming is intentional.
class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
  void SetUp() override {
    // Input-to-gate weights, flattened {n_cell, n_input} = {4, 2} tensors.
    input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
                              0.05100781,  0.04717243,  0.48944736,
                              -0.38535351, -0.17212132};

    input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
                                -0.3633365,  -0.22755712, 0.28253698,
                                0.24407166,  0.33826375};

    input_to_output_weights_ = {0.10725588,  -0.02335852, -0.55932593,
                                -0.09426838, -0.44257352, 0.54939759,
                                0.01533556,  0.42751634};
    // Gate biases, {n_cell} = {4}.  Only the forget gate is biased (to 1).
    cell_gate_bias_ = {0., 0., 0., 0.};
    forget_gate_bias_ = {1., 1., 1., 1.};
    output_gate_bias_ = {0., 0., 0., 0.};

    // Recurrent (output-to-gate) weights, flattened
    // {n_cell, n_output} = {4, 4} tensors.
    recurrent_to_cell_weights_ = {
        0.54066205,  -0.32668582, -0.43562764, -0.56094903,
        0.42957711,  0.01841056,  -0.32764608, -0.33027974,
        -0.10826075, 0.20675004,  0.19069612,  -0.03026325,
        -0.54532051, 0.33003211,  0.44901288,  0.21193194};

    recurrent_to_forget_weights_ = {
        -0.13832897, -0.0515101,  -0.2359007, -0.16661474,
        -0.14340827, 0.36986142,  0.23414481, 0.55899,
        0.10798943,  -0.41174671, 0.17751795, -0.34484994,
        -0.35874045, -0.11352962, 0.27268326, 0.54058349};

    recurrent_to_output_weights_ = {
        0.41613156, 0.42610586,  -0.16495961, -0.5663873,
        0.30579174, -0.05115908, -0.33941799, 0.23364776,
        0.11178309, 0.09481031,  -0.26424935, 0.46261835,
        0.50248802, 0.26114327,  -0.43736315, 0.33149987};

    // Peephole weights, one value per cell ({n_cell} = {4}).
    cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
                               0.31544167};
    cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
                               -0.77109635};

    // One batch: a 3-step sequence with 2 inputs per step, and the matching
    // golden outputs with 4 values per step.
    lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
    lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
                            -0.42312205, -0.01218222, 0.24201041, -0.08124574,
                            -0.358325, -0.04621704, 0.21641694, -0.06471302}};
  }
};
3706 
TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest,LstmBlackBoxTest)3707 TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
3708   const int n_batch = 1;
3709   const int n_input = 2;
3710   // n_cell and n_output have the same size when there is no projection.
3711   const int n_cell = 4;
3712   const int n_output = 4;
3713 
3714   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
3715                    /*use_cifg=*/true, /*use_peephole=*/true,
3716                    /*use_projection_weights=*/false,
3717                    /*use_projection_bias=*/false,
3718                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
3719                    {
3720                        {n_batch, n_input},  // input tensor
3721 
3722                        {0, 0},             // input_to_input_weight tensor
3723                        {n_cell, n_input},  // input_to_forget_weight tensor
3724                        {n_cell, n_input},  // input_to_cell_weight tensor
3725                        {n_cell, n_input},  // input_to_output_weight tensor
3726 
3727                        {0, 0},              // recurrent_to_input_weight tensor
3728                        {n_cell, n_output},  // recurrent_to_forget_weight tensor
3729                        {n_cell, n_output},  // recurrent_to_cell_weight tensor
3730                        {n_cell, n_output},  // recurrent_to_output_weight tensor
3731 
3732                        {0},       // cell_to_input_weight tensor
3733                        {n_cell},  // cell_to_forget_weight tensor
3734                        {n_cell},  // cell_to_output_weight tensor
3735 
3736                        {0},       // input_gate_bias tensor
3737                        {n_cell},  // forget_gate_bias tensor
3738                        {n_cell},  // cell_bias tensor
3739                        {n_cell},  // output_gate_bias tensor
3740 
3741                        {0, 0},  // projection_weight tensor
3742                        {0},     // projection_bias tensor
3743 
3744                        {n_batch, n_output},  // activation_state tensor
3745                        {n_batch, n_cell},    // cell_state tensor
3746                    },
3747                    /*weight_type=*/TensorType_FLOAT32);
3748 
3749   lstm.SetInputToCellWeights(input_to_cell_weights_);
3750   lstm.SetInputToForgetWeights(input_to_forget_weights_);
3751   lstm.SetInputToOutputWeights(input_to_output_weights_);
3752 
3753   lstm.SetCellBias(cell_gate_bias_);
3754   lstm.SetForgetGateBias(forget_gate_bias_);
3755   lstm.SetOutputGateBias(output_gate_bias_);
3756 
3757   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
3758   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
3759   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
3760 
3761   lstm.SetCellToForgetWeights(cell_to_forget_weights_);
3762   lstm.SetCellToOutputWeights(cell_to_output_weights_);
3763 
3764   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
3765 }
3766 
3767 class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
SetUp()3768   void SetUp() override {
3769     input_to_input_weights_ = {
3770         0.021393683,  0.06124551,    0.046905167,  -0.014657677,  -0.03149463,
3771         0.09171803,   0.14647801,    0.10797193,   -0.0057968358, 0.0019193048,
3772         -0.2726754,   0.10154029,    -0.018539885, 0.080349885,   -0.10262385,
3773         -0.022599787, -0.09121155,   -0.008675967, -0.045206103,  -0.0821282,
3774         -0.008045952, 0.015478081,   0.055217247,  0.038719587,   0.044153627,
3775         -0.06453243,  0.05031825,    -0.046935108, -0.008164439,  0.014574226,
3776         -0.1671009,   -0.15519552,   -0.16819797,  -0.13971269,   -0.11953059,
3777         0.25005487,   -0.22790983,   0.009855087,  -0.028140958,  -0.11200698,
3778         0.11295408,   -0.0035217577, 0.054485075,  0.05184695,    0.064711206,
3779         0.10989193,   0.11674786,    0.03490607,   0.07727357,    0.11390585,
3780         -0.1863375,   -0.1034451,    -0.13945189,  -0.049401227,  -0.18767063,
3781         0.042483903,  0.14233552,    0.13832581,   0.18350165,    0.14545603,
3782         -0.028545704, 0.024939531,   0.050929718,  0.0076203286,  -0.0029723682,
3783         -0.042484224, -0.11827596,   -0.09171104,  -0.10808628,   -0.16327988,
3784         -0.2273378,   -0.0993647,    -0.017155107, 0.0023917493,  0.049272764,
3785         0.0038534778, 0.054764505,   0.089753784,  0.06947234,    0.08014476,
3786         -0.04544234,  -0.0497073,    -0.07135631,  -0.048929106,  -0.004042012,
3787         -0.009284026, 0.018042054,   0.0036860977, -0.07427302,   -0.11434604,
3788         -0.018995456, 0.031487543,   0.012834908,  0.019977754,   0.044256654,
3789         -0.39292613,  -0.18519334,   -0.11651281,  -0.06809892,   0.011373677};
3790 
3791     input_to_forget_weights_ = {
3792         -0.0018401089, -0.004852237, 0.03698424,    0.014181704,
3793         0.028273236,   -0.016726194, -0.05249759,   -0.10204261,
3794         0.00861066,    -0.040979505, -0.009899187,  0.01923892,
3795         -0.028177269,  -0.08535103,  -0.14585495,   0.10662567,
3796         -0.01909731,   -0.017883534, -0.0047269356, -0.045103323,
3797         0.0030784295,  0.076784775,  0.07463696,    0.094531395,
3798         0.0814421,     -0.12257899,  -0.033945758,  -0.031303465,
3799         0.045630626,   0.06843887,   -0.13492945,   -0.012480007,
3800         -0.0811829,    -0.07224499,  -0.09628791,   0.045100946,
3801         0.0012300825,  0.013964662,  0.099372394,   0.02543059,
3802         0.06958324,    0.034257296,  0.0482646,     0.06267997,
3803         0.052625068,   0.12784666,   0.07077897,    0.025725935,
3804         0.04165009,    0.07241905,   0.018668644,   -0.037377294,
3805         -0.06277783,   -0.08833636,  -0.040120605,  -0.011405586,
3806         -0.007808335,  -0.010301386, -0.005102167,  0.027717464,
3807         0.05483423,    0.11449111,   0.11289652,    0.10939839,
3808         0.13396506,    -0.08402166,  -0.01901462,   -0.044678304,
3809         -0.07720565,   0.014350063,  -0.11757958,   -0.0652038,
3810         -0.08185733,   -0.076754324, -0.092614375,  0.10405491,
3811         0.052960336,   0.035755895,  0.035839386,   -0.012540553,
3812         0.036881298,   0.02913376,   0.03420159,    0.05448447,
3813         -0.054523353,  0.02582715,   0.02327355,    -0.011857179,
3814         -0.0011980024, -0.034641717, -0.026125094,  -0.17582615,
3815         -0.15923657,   -0.27486774,  -0.0006143371, 0.0001771948,
3816         -8.470171e-05, 0.02651807,   0.045790765,   0.06956496};
3817 
3818     input_to_cell_weights_ = {
3819         -0.04580283,   -0.09549462,   -0.032418985,  -0.06454633,
3820         -0.043528453,  0.043018587,   -0.049152344,  -0.12418144,
3821         -0.078985475,  -0.07596889,   0.019484362,   -0.11434962,
3822         -0.0074034138, -0.06314844,   -0.092981495,  0.0062155537,
3823         -0.025034338,  -0.0028890965, 0.048929527,   0.06235075,
3824         0.10665918,    -0.032036792,  -0.08505916,   -0.10843358,
3825         -0.13002433,   -0.036816437,  -0.02130134,   -0.016518239,
3826         0.0047691227,  -0.0025825808, 0.066017866,   0.029991534,
3827         -0.10652836,   -0.1037554,    -0.13056071,   -0.03266643,
3828         -0.033702414,  -0.006473424,  -0.04611692,   0.014419339,
3829         -0.025174323,  0.0396852,     0.081777506,   0.06157468,
3830         0.10210095,    -0.009658194,  0.046511717,   0.03603906,
3831         0.0069369148,  0.015960095,   -0.06507666,   0.09551598,
3832         0.053568836,   0.06408714,    0.12835667,    -0.008714329,
3833         -0.20211966,   -0.12093674,   0.029450472,   0.2849013,
3834         -0.029227901,  0.1164364,     -0.08560263,   0.09941786,
3835         -0.036999565,  -0.028842626,  -0.0033637602, -0.017012902,
3836         -0.09720865,   -0.11193351,   -0.029155117,  -0.017936034,
3837         -0.009768936,  -0.04223324,   -0.036159635,  0.06505112,
3838         -0.021742892,  -0.023377212,  -0.07221364,   -0.06430552,
3839         0.05453865,    0.091149814,   0.06387331,    0.007518393,
3840         0.055960953,   0.069779344,   0.046411168,   0.10509911,
3841         0.07463894,    0.0075130584,  0.012850982,   0.04555431,
3842         0.056955688,   0.06555285,    0.050801456,   -0.009862683,
3843         0.00826772,    -0.026555609,  -0.0073611983, -0.0014897042};
3844 
3845     input_to_output_weights_ = {
3846         -0.0998932,   -0.07201956,  -0.052803773,  -0.15629593,  -0.15001918,
3847         -0.07650751,  0.02359855,   -0.075155355,  -0.08037709,  -0.15093534,
3848         0.029517552,  -0.04751393,  0.010350531,   -0.02664851,  -0.016839722,
3849         -0.023121163, 0.0077019283, 0.012851257,   -0.05040649,  -0.0129761,
3850         -0.021737747, -0.038305793, -0.06870586,   -0.01481247,  -0.001285394,
3851         0.10124236,   0.083122835,  0.053313006,   -0.062235646, -0.075637154,
3852         -0.027833903, 0.029774971,  0.1130802,     0.09218906,   0.09506135,
3853         -0.086665764, -0.037162706, -0.038880914,  -0.035832845, -0.014481564,
3854         -0.09825003,  -0.12048569,  -0.097665586,  -0.05287633,  -0.0964047,
3855         -0.11366429,  0.035777505,  0.13568819,    0.052451383,  0.050649304,
3856         0.05798951,   -0.021852335, -0.099848844,  0.014740475,  -0.078897946,
3857         0.04974699,   0.014160473,  0.06973932,    0.04964942,   0.033364646,
3858         0.08190124,   0.025535367,  0.050893165,   0.048514254,  0.06945813,
3859         -0.078907564, -0.06707616,  -0.11844508,   -0.09986688,  -0.07509403,
3860         0.06263226,   0.14925587,   0.20188436,    0.12098451,   0.14639415,
3861         0.0015017595, -0.014267382, -0.03417257,   0.012711468,  0.0028300495,
3862         -0.024758482, -0.05098548,  -0.0821182,    0.014225672,  0.021544158,
3863         0.08949725,   0.07505268,   -0.0020780868, 0.04908258,   0.06476295,
3864         -0.022907063, 0.027562456,  0.040185735,   0.019567577,  -0.015598739,
3865         -0.049097303, -0.017121866, -0.083368234,  -0.02332002,  -0.0840956};
3866 
3867     input_gate_bias_ = {0.02234832,   0.14757581,  0.18176508,  0.10380666,
3868                         0.053110216,  -0.06928846, -0.13942584, -0.11816189,
3869                         0.19483899,   0.03652339,  -0.10250295, 0.036714908,
3870                         -0.18426876,  0.036065217, 0.21810818,  0.02383196,
3871                         -0.043370757, 0.08690144,  -0.04444982, 0.00030581196};
3872 
3873     forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
3874                          0.11098921,  0.15378423,   0.09263801,  0.09790885,
3875                          0.09508917,  0.061199076,  0.07665568,  -0.015443159,
3876                          -0.03499149, 0.046190713,  0.08895977,  0.10899629,
3877                          0.40694186,  0.06030037,   0.012413437, -0.06108739};
3878 
3879     cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132,   0.033463873,
3880                        -0.1483596,   -0.10639995,  -0.091433935, 0.058573797,
3881                        -0.06809782,  -0.07889636,  -0.043246906, -0.09829136,
3882                        -0.4279842,   0.034901652,  0.18797937,   0.0075234566,
3883                        0.016178843,  0.1749513,    0.13975595,   0.92058027};
3884 
3885     output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469,   0.12648113,
3886                          0.027195795, 0.35373217,    -0.018957434, 0.008907322,
3887                          -0.0762701,  0.12018895,    0.04216877,   0.0022856654,
3888                          0.040952638, 0.3147856,     0.08225149,   -0.057416286,
3889                          -0.14995944, -0.008040261,  0.13208859,   0.029760877};
3890 
3891     recurrent_to_input_weights_ = {
3892         -0.001374326,   -0.078856036,   0.10672688,    0.029162422,
3893         -0.11585556,    0.02557986,     -0.13446963,   -0.035785314,
3894         -0.01244275,    0.025961924,    -0.02337298,   -0.044228926,
3895         -0.055839065,   -0.046598054,   -0.010546039,  -0.06900766,
3896         0.027239809,    0.022582639,    -0.013296484,  -0.05459212,
3897         0.08981,        -0.045407712,   0.08682226,    -0.06867011,
3898         -0.14390695,    -0.02916037,    0.000996957,   0.091420636,
3899         0.14283475,     -0.07390571,    -0.06402044,   0.062524505,
3900         -0.093129106,   0.04860203,     -0.08364217,   -0.08119002,
3901         0.009352075,    0.22920375,     0.0016303885,  0.11583097,
3902         -0.13732095,    0.012405723,    -0.07551853,   0.06343048,
3903         0.12162708,     -0.031923793,   -0.014335606,  0.01790974,
3904         -0.10650317,    -0.0724401,     0.08554849,    -0.05727212,
3905         0.06556731,     -0.042729504,   -0.043227166,  0.011683251,
3906         -0.013082158,   -0.029302018,   -0.010899579,  -0.062036745,
3907         -0.022509435,   -0.00964907,    -0.01567329,   0.04260106,
3908         -0.07787477,    -0.11576462,    0.017356863,   0.048673786,
3909         -0.017577527,   -0.05527947,    -0.082487635,  -0.040137455,
3910         -0.10820036,    -0.04666372,    0.022746278,   -0.07851417,
3911         0.01068115,     0.032956902,    0.022433773,   0.0026891115,
3912         0.08944216,     -0.0685835,     0.010513544,   0.07228705,
3913         0.02032331,     -0.059686817,   -0.0005566496, -0.086984694,
3914         0.040414046,    -0.1380399,     0.094208956,   -0.05722982,
3915         0.012092817,    -0.04989123,    -0.086576,     -0.003399834,
3916         -0.04696032,    -0.045747425,   0.10091314,    0.048676282,
3917         -0.029037097,   0.031399418,    -0.0040285117, 0.047237843,
3918         0.09504992,     0.041799378,    -0.049185462,  -0.031518843,
3919         -0.10516937,    0.026374253,    0.10058866,    -0.0033195973,
3920         -0.041975245,   0.0073591834,   0.0033782164,  -0.004325073,
3921         -0.10167381,    0.042500053,    -0.01447153,   0.06464186,
3922         -0.017142897,   0.03312627,     0.009205989,   0.024138335,
3923         -0.011337001,   0.035530265,    -0.010912711,  0.0706555,
3924         -0.005894094,   0.051841937,    -0.1401738,    -0.02351249,
3925         0.0365468,      0.07590991,     0.08838724,    0.021681072,
3926         -0.10086113,    0.019608743,    -0.06195883,   0.077335775,
3927         0.023646897,    -0.095322326,   0.02233014,    0.09756986,
3928         -0.048691444,   -0.009579111,   0.07595467,    0.11480546,
3929         -0.09801813,    0.019894179,    0.08502348,    0.004032281,
3930         0.037211012,    0.068537936,    -0.048005626,  -0.091520436,
3931         -0.028379958,   -0.01556313,    0.06554592,    -0.045599163,
3932         -0.01672207,    -0.020169014,   -0.011877351,  -0.20212261,
3933         0.010889619,    0.0047078193,   0.038385306,   0.08540671,
3934         -0.017140968,   -0.0035865551,  0.016678626,   0.005633034,
3935         0.015963363,    0.00871737,     0.060130805,   0.028611384,
3936         0.10109069,     -0.015060172,   -0.07894427,   0.06401885,
3937         0.011584063,    -0.024466386,   0.0047652307,  -0.09041358,
3938         0.030737216,    -0.0046374933,  0.14215417,    -0.11823516,
3939         0.019899689,    0.006106124,    -0.027092824,  0.0786356,
3940         0.05052217,     -0.058925,      -0.011402121,  -0.024987547,
3941         -0.0013661642,  -0.06832946,    -0.015667673,  -0.1083353,
3942         -0.00096863037, -0.06988685,    -0.053350925,  -0.027275559,
3943         -0.033664223,   -0.07978348,    -0.025200296,  -0.017207067,
3944         -0.058403496,   -0.055697463,   0.005798788,   0.12965427,
3945         -0.062582195,   0.0013350133,   -0.10482091,   0.0379771,
3946         0.072521195,    -0.0029455067,  -0.13797039,   -0.03628521,
3947         0.013806405,    -0.017858358,   -0.01008298,   -0.07700066,
3948         -0.017081132,   0.019358726,    0.0027079724,  0.004635139,
3949         0.062634714,    -0.02338735,    -0.039547626,  -0.02050681,
3950         0.03385117,     -0.083611414,   0.002862572,   -0.09421313,
3951         0.058618143,    -0.08598433,    0.00972939,    0.023867095,
3952         -0.053934585,   -0.023203006,   0.07452513,    -0.048767887,
3953         -0.07314807,    -0.056307215,   -0.10433547,   -0.06440842,
3954         0.04328182,     0.04389765,     -0.020006588,  -0.09076438,
3955         -0.11652589,    -0.021705797,   0.03345259,    -0.010329105,
3956         -0.025767034,   0.013057034,    -0.07316461,   -0.10145612,
3957         0.06358255,     0.18531723,     0.07759293,    0.12006465,
3958         0.1305557,      0.058638252,    -0.03393652,   0.09622831,
3959         -0.16253184,    -2.4580743e-06, 0.079869635,   -0.070196845,
3960         -0.005644518,   0.06857898,     -0.12598175,   -0.035084512,
3961         0.03156317,     -0.12794146,    -0.031963028,  0.04692781,
3962         0.030070418,    0.0071660685,   -0.095516115,  -0.004643372,
3963         0.040170413,    -0.062104587,   -0.0037324072, 0.0554317,
3964         0.08184801,     -0.019164372,   0.06791302,    0.034257166,
3965         -0.10307039,    0.021943003,    0.046745934,   0.0790918,
3966         -0.0265588,     -0.007824208,   0.042546265,   -0.00977924,
3967         -0.0002440307,  -0.017384544,   -0.017990116,  0.12252321,
3968         -0.014512694,   -0.08251313,    0.08861942,    0.13589665,
3969         0.026351685,    0.012641483,    0.07466548,    0.044301085,
3970         -0.045414884,   -0.051112458,   0.03444247,    -0.08502782,
3971         -0.04106223,    -0.028126027,   0.028473156,   0.10467447};
3972 
3973     recurrent_to_cell_weights_ = {
3974         -0.037322544,   0.018592842,   0.0056175636,  -0.06253426,
3975         0.055647098,    -0.05713207,   -0.05626563,   0.005559383,
3976         0.03375411,     -0.025757805,  -0.088049285,  0.06017052,
3977         -0.06570978,    0.007384076,   0.035123326,   -0.07920549,
3978         0.053676967,    0.044480428,   -0.07663568,   0.0071805613,
3979         0.08089997,     0.05143358,    0.038261272,   0.03339287,
3980         -0.027673481,   0.044746667,   0.028349208,   0.020090483,
3981         -0.019443132,   -0.030755889,  -0.0040000007, 0.04465846,
3982         -0.021585021,   0.0031670958,  0.0053199246,  -0.056117613,
3983         -0.10893326,    0.076739706,   -0.08509834,   -0.027997585,
3984         0.037871376,    0.01449768,    -0.09002357,   -0.06111149,
3985         -0.046195522,   0.0422062,     -0.005683705,  -0.1253618,
3986         -0.012925729,   -0.04890792,   0.06985068,    0.037654128,
3987         0.03398274,     -0.004781977,  0.007032333,   -0.031787455,
3988         0.010868644,    -0.031489216,  0.09525667,    0.013939797,
3989         0.0058680447,   0.0167067,     0.02668468,    -0.04797466,
3990         -0.048885044,   -0.12722108,   0.035304096,   0.06554885,
3991         0.00972396,     -0.039238118,  -0.05159735,   -0.11329045,
3992         0.1613692,      -0.03750952,   0.06529313,    -0.071974665,
3993         -0.11769596,    0.015524369,   -0.0013754242, -0.12446318,
3994         0.02786344,     -0.014179351,  0.005264273,   0.14376344,
3995         0.015983658,    0.03406988,    -0.06939408,   0.040699873,
3996         0.02111075,     0.09669095,    0.041345075,   -0.08316494,
3997         -0.07684199,    -0.045768797,  0.032298047,   -0.041805092,
3998         0.0119405,      0.0061010392,  0.12652606,    0.0064572375,
3999         -0.024950314,   0.11574242,    0.04508852,    -0.04335324,
4000         0.06760663,     -0.027437469,  0.07216407,    0.06977076,
4001         -0.05438599,    0.034033038,   -0.028602652,  0.05346137,
4002         0.043184172,    -0.037189785,  0.10420091,    0.00882477,
4003         -0.054019816,   -0.074273005,  -0.030617684,  -0.0028467078,
4004         0.024302477,    -0.0038869337, 0.005332455,   0.0013399826,
4005         0.04361412,     -0.007001822,  0.09631092,    -0.06702025,
4006         -0.042049985,   -0.035070654,  -0.04103342,   -0.10273396,
4007         0.0544271,      0.037184782,   -0.13150354,   -0.0058036847,
4008         -0.008264958,   0.042035464,   0.05891794,    0.029673764,
4009         0.0063542654,   0.044788733,   0.054816857,   0.062257513,
4010         -0.00093483756, 0.048938446,   -0.004952862,  -0.007730018,
4011         -0.04043371,    -0.017094059,  0.07229206,    -0.023670016,
4012         -0.052195564,   -0.025616996,  -0.01520939,   0.045104615,
4013         -0.007376126,   0.003533447,   0.006570588,   0.056037236,
4014         0.12436656,     0.051817212,   0.028532185,   -0.08686856,
4015         0.11868599,     0.07663395,    -0.07323171,   0.03463402,
4016         -0.050708205,   -0.04458982,   -0.11590894,   0.021273347,
4017         0.1251325,      -0.15313013,   -0.12224372,   0.17228661,
4018         0.023029093,    0.086124025,   0.006445803,   -0.03496501,
4019         0.028332196,    0.04449512,    -0.042436164,  -0.026587414,
4020         -0.006041347,   -0.09292539,   -0.05678812,   0.03897832,
4021         0.09465633,     0.008115513,   -0.02171956,   0.08304309,
4022         0.071401566,    0.019622514,   0.032163795,   -0.004167056,
4023         0.02295182,     0.030739572,   0.056506045,   0.004612461,
4024         0.06524936,     0.059999723,   0.046395954,   -0.0045512207,
4025         -0.1335546,     -0.030136576,  0.11584653,    -0.014678886,
4026         0.0020118146,   -0.09688814,   -0.0790206,    0.039770417,
4027         -0.0329582,     0.07922767,    0.029322514,   0.026405897,
4028         0.04207835,     -0.07073373,   0.063781224,   0.0859677,
4029         -0.10925287,    -0.07011058,   0.048005477,   0.03438226,
4030         -0.09606514,    -0.006669445,  -0.043381985,  0.04240257,
4031         -0.06955775,    -0.06769346,   0.043903265,   -0.026784198,
4032         -0.017840602,   0.024307009,   -0.040079936,  -0.019946516,
4033         0.045318738,    -0.12233574,   0.026170589,   0.0074471775,
4034         0.15978073,     0.10185836,    0.10298046,    -0.015476589,
4035         -0.039390966,   -0.072174534,  0.0739445,     -0.1211869,
4036         -0.0347889,     -0.07943156,   0.014809798,   -0.12412325,
4037         -0.0030663363,  0.039695457,   0.0647603,     -0.08291318,
4038         -0.018529687,   -0.004423833,  0.0037507233,  0.084633216,
4039         -0.01514876,    -0.056505352,  -0.012800942,  -0.06994386,
4040         0.012962922,    -0.031234352,  0.07029052,    0.016418684,
4041         0.03618972,     0.055686004,   -0.08663945,   -0.017404709,
4042         -0.054761406,   0.029065743,   0.052404847,   0.020238016,
4043         0.0048197987,   -0.0214882,    0.07078733,    0.013016777,
4044         0.06262858,     0.009184685,   0.020785125,   -0.043904778,
4045         -0.0270329,     -0.03299152,   -0.060088247,  -0.015162964,
4046         -0.001828936,   0.12642565,    -0.056757294,  0.013586685,
4047         0.09232601,     -0.035886683,  0.06000002,    0.05229691,
4048         -0.052580316,   -0.082029596,  -0.010794592,  0.012947712,
4049         -0.036429964,   -0.085508935,  -0.13127148,   -0.017744139,
4050         0.031502828,    0.036232427,   -0.031581745,  0.023051167,
4051         -0.05325106,    -0.03421577,   0.028793324,   -0.034633752,
4052         -0.009881397,   -0.043551125,  -0.018609839,  0.0019097115,
4053         -0.008799762,   0.056595087,   0.0022273948,  0.055752404};
4054 
4055     recurrent_to_forget_weights_ = {
4056         -0.057784554,  -0.026057621,  -0.068447545,   -0.022581743,
4057         0.14811787,    0.10826372,    0.09471067,     0.03987225,
4058         -0.0039523416, 0.00030638507, 0.053185795,    0.10572994,
4059         0.08414449,    -0.022036452,  -0.00066928595, -0.09203576,
4060         0.032950465,   -0.10985798,   -0.023809856,   0.0021431844,
4061         -0.02196096,   -0.00326074,   0.00058621005,  -0.074678116,
4062         -0.06193199,   0.055729095,   0.03736828,     0.020123724,
4063         0.061878487,   -0.04729229,   0.034919553,    -0.07585433,
4064         -0.04421272,   -0.044019096,  0.085488975,    0.04058006,
4065         -0.06890133,   -0.030951202,  -0.024628663,   -0.07672815,
4066         0.034293607,   0.08556707,    -0.05293577,    -0.033561368,
4067         -0.04899627,   0.0241671,     0.015736353,    -0.095442444,
4068         -0.029564252,  0.016493602,   -0.035026584,   0.022337519,
4069         -0.026871363,  0.004780428,   0.0077918363,   -0.03601621,
4070         0.016435321,   -0.03263031,   -0.09543275,    -0.047392778,
4071         0.013454138,   0.028934088,   0.01685226,     -0.086110644,
4072         -0.046250615,  -0.01847454,   0.047608484,    0.07339695,
4073         0.034546845,   -0.04881143,   0.009128804,    -0.08802852,
4074         0.03761666,    0.008096139,   -0.014454086,   0.014361001,
4075         -0.023502491,  -0.0011840804, -0.07607001,    0.001856849,
4076         -0.06509276,   -0.006021153,  -0.08570962,    -0.1451793,
4077         0.060212336,   0.055259194,   0.06974018,     0.049454916,
4078         -0.027794661,  -0.08077226,   -0.016179763,   0.1169753,
4079         0.17213494,    -0.0056326236, -0.053934924,   -0.0124349,
4080         -0.11520337,   0.05409887,    0.088759385,    0.0019655675,
4081         0.0042065294,  0.03881498,    0.019844765,    0.041858196,
4082         -0.05695512,   0.047233116,   0.038937137,    -0.06542224,
4083         0.014429736,   -0.09719407,   0.13908425,     -0.05379757,
4084         0.012321099,   0.082840554,   -0.029899208,   0.044217527,
4085         0.059855383,   0.07711018,    -0.045319796,   0.0948846,
4086         -0.011724666,  -0.0033288454, -0.033542685,   -0.04764985,
4087         -0.13873616,   0.040668588,   0.034832682,    -0.015319203,
4088         -0.018715994,  0.046002675,   0.0599172,      -0.043107376,
4089         0.0294216,     -0.002314414,  -0.022424703,   0.0030315618,
4090         0.0014641669,  0.0029166266,  -0.11878115,    0.013738511,
4091         0.12375372,    -0.0006038222, 0.029104086,    0.087442465,
4092         0.052958444,   0.07558703,    0.04817258,     0.044462286,
4093         -0.015213451,  -0.08783778,   -0.0561384,     -0.003008196,
4094         0.047060397,   -0.002058388,  0.03429439,     -0.018839769,
4095         0.024734668,   0.024614193,   -0.042046934,   0.09597743,
4096         -0.0043254104, 0.04320769,    0.0064070094,   -0.0019131786,
4097         -0.02558259,   -0.022822596,  -0.023273505,   -0.02464396,
4098         -0.10991725,   -0.006240552,  0.0074488563,   0.024044557,
4099         0.04383914,    -0.046476185,  0.028658995,    0.060410924,
4100         0.050786525,   0.009452605,   -0.0073054377,  -0.024810238,
4101         0.0052906186,  0.0066939713,  -0.0020913032,  0.014515517,
4102         0.015898481,   0.021362653,   -0.030262267,   0.016587038,
4103         -0.011442813,  0.041154444,   -0.007631438,   -0.03423484,
4104         -0.010977775,  0.036152758,   0.0066366293,   0.11915515,
4105         0.02318443,    -0.041350313,  0.021485701,    -0.10906167,
4106         -0.028218046,  -0.00954771,   0.020531068,    -0.11995105,
4107         -0.03672871,   0.024019798,   0.014255957,    -0.05221243,
4108         -0.00661567,   -0.04630967,   0.033188973,    0.10107534,
4109         -0.014027541,  0.030796422,   -0.10270911,    -0.035999842,
4110         0.15443139,    0.07684145,    0.036571592,    -0.035900835,
4111         -0.0034699554, 0.06209149,    0.015920248,    -0.031122351,
4112         -0.03858649,   0.01849943,    0.13872518,     0.01503974,
4113         0.069941424,   -0.06948533,   -0.0088794185,  0.061282158,
4114         -0.047401894,  0.03100163,    -0.041533746,   -0.10430945,
4115         0.044574402,   -0.01425562,   -0.024290353,   0.034563623,
4116         0.05866852,    0.023947537,   -0.09445152,    0.035450947,
4117         0.02247216,    -0.0042998926, 0.061146557,    -0.10250651,
4118         0.020881841,   -0.06747029,   0.10062043,     -0.0023941975,
4119         0.03532124,    -0.016341697,  0.09685456,     -0.016764693,
4120         0.051808182,   0.05875331,    -0.04536488,    0.001626336,
4121         -0.028892258,  -0.01048663,   -0.009793449,   -0.017093895,
4122         0.010987891,   0.02357273,    -0.00010856845, 0.0099760275,
4123         -0.001845119,  -0.03551521,   0.0018358806,   0.05763657,
4124         -0.01769146,   0.040995963,   0.02235177,     -0.060430344,
4125         0.11475477,    -0.023854522,  0.10071741,     0.0686208,
4126         -0.014250481,  0.034261297,   0.047418304,    0.08562733,
4127         -0.030519066,  0.0060542435,  0.014653856,    -0.038836084,
4128         0.04096551,    0.032249358,   -0.08355519,    -0.026823482,
4129         0.056386515,   -0.010401743,  -0.028396193,   0.08507674,
4130         0.014410365,   0.020995233,   0.17040324,     0.11511526,
4131         0.02459721,    0.0066619175,  0.025853224,    -0.023133837,
4132         -0.081302024,  0.017264642,   -0.009585969,   0.09491168,
4133         -0.051313367,  0.054532815,   -0.014298593,   0.10657464,
4134         0.007076659,   0.10964551,    0.0409152,      0.008275321,
4135         -0.07283536,   0.07937492,    0.04192024,     -0.1075027};
4136 
4137     recurrent_to_output_weights_ = {
4138         0.025825322,   -0.05813119,   0.09495884,     -0.045984812,
4139         -0.01255415,   -0.0026479573, -0.08196161,    -0.054914974,
4140         -0.0046604523, -0.029587349,  -0.044576716,   -0.07480124,
4141         -0.082868785,  0.023254942,   0.027502948,    -0.0039728214,
4142         -0.08683098,   -0.08116779,   -0.014675607,   -0.037924774,
4143         -0.023314456,  -0.007401714,  -0.09255757,    0.029460307,
4144         -0.08829125,   -0.005139627,  -0.08989442,    -0.0555066,
4145         0.13596267,    -0.025062224,  -0.048351806,   -0.03850004,
4146         0.07266485,    -0.022414139,  0.05940088,     0.075114764,
4147         0.09597592,    -0.010211725,  -0.0049794707,  -0.011523867,
4148         -0.025980417,  0.072999895,   0.11091378,     -0.081685916,
4149         0.014416728,   0.043229222,   0.034178585,    -0.07530371,
4150         0.035837382,   -0.085607,     -0.007721233,   -0.03287832,
4151         -0.043848954,  -0.06404588,   -0.06632928,    -0.073643476,
4152         0.008214239,   -0.045984086,  0.039764922,    0.03474462,
4153         0.060612556,   -0.080590084,  0.049127717,    0.04151091,
4154         -0.030063879,  0.008801774,   -0.023021035,   -0.019558564,
4155         0.05158114,    -0.010947698,  -0.011825728,   0.0075720972,
4156         0.0699727,     -0.0039981045, 0.069350146,    0.08799282,
4157         0.016156472,   0.035502106,   0.11695009,     0.006217345,
4158         0.13392477,    -0.037875112,  0.025745004,    0.08940699,
4159         -0.00924166,   0.0046702605,  -0.036598757,   -0.08811812,
4160         0.10522024,    -0.032441203,  0.008176899,    -0.04454919,
4161         0.07058152,    0.0067963637,  0.039206743,    0.03259838,
4162         0.03725492,    -0.09515802,   0.013326398,    -0.052055415,
4163         -0.025676316,  0.03198509,    -0.015951829,   -0.058556724,
4164         0.036879618,   0.043357447,   0.028362012,    -0.05908629,
4165         0.0059240665,  -0.04995891,   -0.019187413,   0.0276265,
4166         -0.01628143,   0.0025863599,  0.08800015,     0.035250366,
4167         -0.022165963,  -0.07328642,   -0.009415526,   -0.07455109,
4168         0.11690406,    0.0363299,     0.07411125,     0.042103454,
4169         -0.009660886,  0.019076364,   0.018299393,    -0.046004917,
4170         0.08891175,    0.0431396,     -0.026327137,   -0.051502608,
4171         0.08979574,    -0.051670972,  0.04940282,     -0.07491107,
4172         -0.021240504,  0.022596184,   -0.034280192,   0.060163025,
4173         -0.058211457,  -0.051837247,  -0.01349775,    -0.04639988,
4174         -0.035936575,  -0.011681591,  0.064818054,    0.0073146066,
4175         -0.021745546,  -0.043124277,  -0.06471268,    -0.07053354,
4176         -0.029321948,  -0.05330136,   0.016933719,    -0.053782392,
4177         0.13747959,    -0.1361751,    -0.11569455,    0.0033329215,
4178         0.05693899,    -0.053219706,  0.063698,       0.07977434,
4179         -0.07924483,   0.06936997,    0.0034815092,   -0.007305279,
4180         -0.037325785,  -0.07251102,   -0.033633437,   -0.08677009,
4181         0.091591336,   -0.14165086,   0.021752775,    0.019683983,
4182         0.0011612234,  -0.058154266,  0.049996935,    0.0288841,
4183         -0.0024567875, -0.14345716,   0.010955264,    -0.10234828,
4184         0.1183656,     -0.0010731248, -0.023590032,   -0.072285876,
4185         -0.0724771,    -0.026382286,  -0.0014920527,  0.042667855,
4186         0.0018776858,  0.02986552,    0.009814309,    0.0733756,
4187         0.12289186,    0.018043943,   -0.0458958,     0.049412545,
4188         0.033632483,   0.05495232,    0.036686596,    -0.013781798,
4189         -0.010036754,  0.02576849,    -0.08307328,    0.010112348,
4190         0.042521734,   -0.05869831,   -0.071689695,   0.03876447,
4191         -0.13275425,   -0.0352966,    -0.023077697,   0.10285965,
4192         0.084736146,   0.15568255,    -0.00040734606, 0.027835453,
4193         -0.10292561,   -0.032401145,  0.10053256,     -0.026142767,
4194         -0.08271222,   -0.0030240538, -0.016368777,   0.1070414,
4195         0.042672627,   0.013456989,   -0.0437609,     -0.022309763,
4196         0.11576483,    0.04108048,    0.061026827,    -0.0190714,
4197         -0.0869359,    0.037901703,   0.0610107,      0.07202949,
4198         0.01675338,    0.086139716,   -0.08795751,    -0.014898893,
4199         -0.023771819,  -0.01965048,   0.007955471,    -0.043740474,
4200         0.03346837,    -0.10549954,   0.090567775,    0.042013682,
4201         -0.03176985,   0.12569028,    -0.02421228,    -0.029526481,
4202         0.023851605,   0.031539805,   0.05292009,     -0.02344001,
4203         -0.07811758,   -0.08834428,   0.10094801,     0.16594367,
4204         -0.06861939,   -0.021256343,  -0.041093912,   -0.06669611,
4205         0.035498552,   0.021757556,   -0.09302526,    -0.015403468,
4206         -0.06614931,   -0.051798206,  -0.013874718,   0.03630673,
4207         0.010412845,   -0.08077351,   0.046185967,    0.0035662893,
4208         0.03541868,    -0.094149634,  -0.034814864,   0.003128424,
4209         -0.020674974,  -0.03944324,   -0.008110165,   -0.11113267,
4210         0.08484226,    0.043586485,   0.040582247,    0.0968012,
4211         -0.065249965,  -0.028036479,  0.0050708856,   0.0017462453,
4212         0.0326779,     0.041296225,   0.09164146,     -0.047743853,
4213         -0.015952192,  -0.034451712,  0.084197424,    -0.05347844,
4214         -0.11768019,   0.085926116,   -0.08251791,    -0.045081906,
4215         0.0948852,     0.068401024,   0.024856757,    0.06978981,
4216         -0.057309967,  -0.012775832,  -0.0032452994,  0.01977615,
4217         -0.041040014,  -0.024264973,  0.063464895,    0.05431621,
4218     };
4219 
4220     cell_to_input_weights_ = {
4221         0.040369894, 0.030746894,  0.24704495,  0.018586371,  -0.037586458,
4222         -0.15312155, -0.11812848,  -0.11465643, 0.20259799,   0.11418174,
4223         -0.10116027, -0.011334949, 0.12411352,  -0.076769054, -0.052169047,
4224         0.21198851,  -0.38871562,  -0.09061183, -0.09683246,  -0.21929175};
4225 
4226     cell_to_forget_weights_ = {
4227         -0.01998659,  -0.15568835,  -0.24248174,   -0.012770197, 0.041331276,
4228         -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
4229         -0.047248036, 0.021479502,  0.033189066,   0.11952997,   -0.020432774,
4230         0.64658105,   -0.06650122,  -0.03467612,   0.095340036,  0.23647355};
4231 
4232     cell_to_output_weights_ = {
4233         0.08286371,  -0.08261836, -0.51210177, 0.002913762, 0.17764764,
4234         -0.5495371,  -0.08460716, -0.24552552, 0.030037103, 0.04123544,
4235         -0.11940523, 0.007358328, 0.1890978,   0.4833202,   -0.34441817,
4236         0.36312827,  -0.26375428, 0.1457655,   -0.19724406, 0.15548733};
4237 
4238     projection_weights_ = {
4239         -0.009802181, 0.09401916,   0.0717386,     -0.13895074,
4240         0.09641832,   0.060420845,  0.08539281,    0.054285463,
4241         0.061395317,  0.034448683,  -0.042991187,  0.019801661,
4242         -0.16840284,  -0.015726732, -0.23041931,   -0.024478018,
4243         -0.10959692,  -0.013875541, 0.18600968,    -0.061274476,
4244         0.0138165,    -0.08160894,  -0.07661644,   0.032372914,
4245         0.16169067,   0.22465782,   -0.03993472,   -0.004017731,
4246         0.08633481,   -0.28869787,  0.08682067,    0.17240396,
4247         0.014975425,  0.056431185,  0.031037588,   0.16702051,
4248         0.0077946745, 0.15140012,   0.29405436,    0.120285,
4249         -0.188994,    -0.027265169, 0.043389652,   -0.022061434,
4250         0.014777949,  -0.20203483,  0.094781205,   0.19100232,
4251         0.13987629,   -0.036132768, -0.06426278,   -0.05108664,
4252         0.13221376,   0.009441198,  -0.16715929,   0.15859416,
4253         -0.040437475, 0.050779544,  -0.022187516,  0.012166504,
4254         0.027685808,  -0.07675938,  -0.0055694645, -0.09444123,
4255         0.0046453946, 0.050794356,  0.10770313,    -0.20790008,
4256         -0.07149004,  -0.11425117,  0.008225835,   -0.035802525,
4257         0.14374903,   0.15262283,   0.048710253,   0.1847461,
4258         -0.007487823, 0.11000021,   -0.09542012,   0.22619456,
4259         -0.029149994, 0.08527916,   0.009043713,   0.0042746216,
4260         0.016261552,  0.022461696,  0.12689082,    -0.043589946,
4261         -0.12035478,  -0.08361797,  -0.050666027,  -0.1248618,
4262         -0.1275799,   -0.071875185, 0.07377272,    0.09944291,
4263         -0.18897448,  -0.1593054,   -0.06526116,   -0.040107165,
4264         -0.004618631, -0.067624845, -0.007576253,  0.10727444,
4265         0.041546922,  -0.20424393,  0.06907816,    0.050412357,
4266         0.00724631,   0.039827548,  0.12449835,    0.10747581,
4267         0.13708383,   0.09134148,   -0.12617786,   -0.06428341,
4268         0.09956831,   0.1208086,    -0.14676677,   -0.0727722,
4269         0.1126304,    0.010139365,  0.015571211,   -0.038128063,
4270         0.022913318,  -0.042050496, 0.16842307,    -0.060597885,
4271         0.10531834,   -0.06411776,  -0.07451711,   -0.03410368,
4272         -0.13393489,  0.06534304,   0.003620307,   0.04490757,
4273         0.05970546,   0.05197996,   0.02839995,    0.10434969,
4274         -0.013699693, -0.028353551, -0.07260381,   0.047201227,
4275         -0.024575593, -0.036445823, 0.07155557,    0.009672501,
4276         -0.02328883,  0.009533515,  -0.03606021,   -0.07421458,
4277         -0.028082801, -0.2678904,   -0.13221288,   0.18419984,
4278         -0.13012612,  -0.014588381, -0.035059117,  -0.04824723,
4279         0.07830115,   -0.056184657, 0.03277091,    0.025466874,
4280         0.14494097,   -0.12522776,  -0.098633975,  -0.10766018,
4281         -0.08317623,  0.08594209,   0.07749552,    0.039474737,
4282         0.1776665,    -0.07409566,  -0.0477268,    0.29323658,
4283         0.10801441,   0.1154011,    0.013952499,   0.10739139,
4284         0.10708251,   -0.051456142, 0.0074137426,  -0.10430189,
4285         0.10034707,   0.045594677,  0.0635285,     -0.0715442,
4286         -0.089667566, -0.10811871,  0.00026344223, 0.08298446,
4287         -0.009525053, 0.006585689,  -0.24567553,   -0.09450807,
4288         0.09648481,   0.026996298,  -0.06419476,   -0.04752702,
4289         -0.11063944,  -0.23441927,  -0.17608605,   -0.052156363,
4290         0.067035615,  0.19271925,   -0.0032889997, -0.043264326,
4291         0.09663576,   -0.057112187, -0.10100678,   0.0628376,
4292         0.04447668,   0.017961001,  -0.10094388,   -0.10190601,
4293         0.18335468,   0.10494553,   -0.052095775,  -0.0026118709,
4294         0.10539724,   -0.04383912,  -0.042349473,  0.08438151,
4295         -0.1947263,   0.02251204,   0.11216432,    -0.10307853,
4296         0.17351969,   -0.039091777, 0.08066188,    -0.00561982,
4297         0.12633002,   0.11335965,   -0.0088127935, -0.019777594,
4298         0.06864014,   -0.059751723, 0.016233567,   -0.06894641,
4299         -0.28651384,  -0.004228674, 0.019708522,   -0.16305895,
4300         -0.07468996,  -0.0855457,   0.099339016,   -0.07580735,
4301         -0.13775392,  0.08434318,   0.08330512,    -0.12131499,
4302         0.031935584,  0.09180414,   -0.08876437,   -0.08049874,
4303         0.008753825,  0.03498998,   0.030215185,   0.03907079,
4304         0.089751154,  0.029194152,  -0.03337423,   -0.019092513,
4305         0.04331237,   0.04299654,   -0.036394123,  -0.12915532,
4306         0.09793732,   0.07512415,   -0.11319543,   -0.032502122,
4307         0.15661901,   0.07671967,   -0.005491124,  -0.19379048,
4308         -0.218606,    0.21448623,   0.017840758,   0.1416943,
4309         -0.07051762,  0.19488361,   0.02664691,    -0.18104725,
4310         -0.09334311,  0.15026465,   -0.15493552,   -0.057762887,
4311         -0.11604192,  -0.262013,    -0.01391798,   0.012185008,
4312         0.11156489,   -0.07483202,  0.06693364,    -0.26151478,
4313         0.046425626,  0.036540434,  -0.16435726,   0.17338543,
4314         -0.21401681,  -0.11385144,  -0.08283257,   -0.069031075,
4315         0.030635102,  0.010969227,  0.11109743,    0.010919218,
4316         0.027526086,  0.13519906,   0.01891392,    -0.046839405,
4317         -0.040167913, 0.017953383,  -0.09700955,   0.0061885654,
4318         -0.07000971,  0.026893595,  -0.038844477,  0.14543656};
4319 
4320     lstm_input_ = {
4321         {// Batch0: 4 (input_sequence_size) * 5 (n_input)
4322          0.787926, 0.151646, 0.071352, 0.118426, 0.458058,   // step 0
4323          0.596268, 0.998386, 0.568695, 0.864524, 0.571277,   // step 1
4324          0.073204, 0.296072, 0.743333, 0.069199, 0.045348,   // step 2
4325          0.867394, 0.291279, 0.013714, 0.482521, 0.626339},  // step 3
4326 
4327         {// Batch1: 4 (input_sequence_size) * 5 (n_input)
4328          0.295743, 0.544053, 0.690064, 0.858138, 0.497181,  // step 0
4329          0.642421, 0.524260, 0.134799, 0.003639, 0.162482,  // step 1
4330          0.640394, 0.930399, 0.050782, 0.432485, 0.988078,  // step 2
4331          0.082922, 0.563329, 0.865614, 0.333232, 0.259916}  // step 3
4332     };
4333 
4334     lstm_golden_output_ = {
4335         {// Batch0: 4 (input_sequence_size) * 16 (n_output)
4336          -0.00396806, 0.029352,     -0.00279226, 0.0159977,   -0.00835576,
4337          -0.0211779,  0.0283512,    -0.0114597,  0.00907307,  -0.0244004,
4338          -0.0152191,  -0.0259063,   0.00914318,  0.00415118,  0.017147,
4339          0.0134203,   -0.0166936,   0.0381209,   0.000889694, 0.0143363,
4340          -0.0328911,  -0.0234288,   0.0333051,   -0.012229,   0.0110322,
4341          -0.0457725,  -0.000832209, -0.0202817,  0.0327257,   0.0121308,
4342          0.0155969,   0.0312091,    -0.0213783,  0.0350169,   0.000324794,
4343          0.0276012,   -0.0263374,   -0.0371449,  0.0446149,   -0.0205474,
4344          0.0103729,   -0.0576349,   -0.0150052,  -0.0292043,  0.0376827,
4345          0.0136115,   0.0243435,    0.0354492,   -0.0189322,  0.0464512,
4346          -0.00251373, 0.0225745,    -0.0308346,  -0.0317124,  0.0460407,
4347          -0.0189395,  0.0149363,    -0.0530162,  -0.0150767,  -0.0340193,
4348          0.0286833,   0.00824207,   0.0264887,   0.0305169},
4349         {// Batch1: 4 (input_sequence_size) * 16 (n_output)
4350          -0.013869,    0.0287268,   -0.00334693, 0.00733398,  -0.0287926,
4351          -0.0186926,   0.0193662,   -0.0115437,  0.00422612,  -0.0345232,
4352          0.00223253,   -0.00957321, 0.0210624,   0.013331,    0.0150954,
4353          0.02168,      -0.0141913,  0.0322082,   0.00227024,  0.0260507,
4354          -0.0188721,   -0.0296489,  0.0399134,   -0.0160509,  0.0116039,
4355          -0.0447318,   -0.0150515,  -0.0277406,  0.0316596,   0.0118233,
4356          0.0214762,    0.0293641,   -0.0204549,  0.0450315,   -0.00117378,
4357          0.0167673,    -0.0375007,  -0.0238314,  0.038784,    -0.0174034,
4358          0.0131743,    -0.0506589,  -0.0048447,  -0.0240239,  0.0325789,
4359          0.00790065,   0.0220157,   0.0333314,   -0.0264787,  0.0387855,
4360          -0.000764675, 0.0217599,   -0.037537,   -0.0335206,  0.0431679,
4361          -0.0211424,   0.010203,    -0.062785,   -0.00832363, -0.025181,
4362          0.0412031,    0.0118723,   0.0239643,   0.0394009}};
4363   }
4364 };
4365 
TEST_F(NoCifgPeepholeProjectionClippingLstmTest,LstmBlackBoxTest)4366 TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
4367   const int n_batch = 2;
4368   const int n_input = 5;
4369   const int n_cell = 20;
4370   const int n_output = 16;
4371 
4372   LSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
4373                    /*use_cifg=*/false, /*use_peephole=*/true,
4374                    /*use_projection_weights=*/true,
4375                    /*use_projection_bias=*/false,
4376                    /*cell_clip=*/0.0, /*proj_clip=*/0.0,
4377                    {
4378                        {n_batch, n_input},  // input tensor
4379 
4380                        {n_cell, n_input},  // input_to_input_weight tensor
4381                        {n_cell, n_input},  // input_to_forget_weight tensor
4382                        {n_cell, n_input},  // input_to_cell_weight tensor
4383                        {n_cell, n_input},  // input_to_output_weight tensor
4384 
4385                        {n_cell, n_output},  // recurrent_to_input_weight tensor
4386                        {n_cell, n_output},  // recurrent_to_forget_weight tensor
4387                        {n_cell, n_output},  // recurrent_to_cell_weight tensor
4388                        {n_cell, n_output},  // recurrent_to_output_weight tensor
4389 
4390                        {n_cell},  // cell_to_input_weight tensor
4391                        {n_cell},  // cell_to_forget_weight tensor
4392                        {n_cell},  // cell_to_output_weight tensor
4393 
4394                        {n_cell},  // input_gate_bias tensor
4395                        {n_cell},  // forget_gate_bias tensor
4396                        {n_cell},  // cell_bias tensor
4397                        {n_cell},  // output_gate_bias tensor
4398 
4399                        {n_output, n_cell},  // projection_weight tensor
4400                        {0},                 // projection_bias tensor
4401 
4402                        {n_batch, n_output},  // activation_state tensor
4403                        {n_batch, n_cell},    // cell_state tensor
4404                    },
4405                    /*weight_type=*/TensorType_FLOAT32);
4406 
4407   lstm.SetInputToInputWeights(input_to_input_weights_);
4408   lstm.SetInputToCellWeights(input_to_cell_weights_);
4409   lstm.SetInputToForgetWeights(input_to_forget_weights_);
4410   lstm.SetInputToOutputWeights(input_to_output_weights_);
4411 
4412   lstm.SetInputGateBias(input_gate_bias_);
4413   lstm.SetCellBias(cell_gate_bias_);
4414   lstm.SetForgetGateBias(forget_gate_bias_);
4415   lstm.SetOutputGateBias(output_gate_bias_);
4416 
4417   lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
4418   lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
4419   lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
4420   lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
4421 
4422   lstm.SetCellToInputWeights(cell_to_input_weights_);
4423   lstm.SetCellToForgetWeights(cell_to_forget_weights_);
4424   lstm.SetCellToOutputWeights(cell_to_output_weights_);
4425 
4426   lstm.SetProjectionWeights(projection_weights_);
4427 
4428   VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
4429 }
4430 
// Fixture for a layer-normalized LSTM: no CIFG, peephole connections and a
// projection layer enabled, clipping disabled. SetUp() populates the
// BaseLstmTest members for a 4-cell, 5-input, 3-output model (the dimensions
// the corresponding test instantiates the op with).
class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
    : public BaseLstmTest {
  void SetUp() override {
    // Input-to-gate weights: n_cell * n_input = 4 * 5 values per gate.
    input_to_input_weights_ = {0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,
                               0.3,  -0.4, 0.5,  -0.8, 0.7,  -0.6, 0.5,
                               -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};

    input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
                                -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
                                -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};

    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
                              -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
                              -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};

    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                                -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
                                -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};

    // Gate biases: one value per cell (n_cell = 4).
    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};

    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};

    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};

    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Recurrent weights: n_cell * n_output = 4 * 3 values per gate.
    recurrent_to_input_weights_ = {-0.2, -0.3, 0.4,  0.1,  -0.5, 0.9,
                                   -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};

    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};

    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                    0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};

    recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

    // Peephole (cell-to-gate) weights: one value per cell.
    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};

    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};

    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Per-gate layer-normalization scale coefficients, n_cell values each.
    input_layer_norm_coefficients_ = {0.1, 0.2, 0.3, 0.5};
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};

    // Projection weights: n_output * n_cell = 3 * 4 values.
    projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
                           0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};

    // Two input batches, each three 5-element time steps.
    lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};
4497 
TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,LayerNormLstmBlackBoxTest)4498 TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
4499        LayerNormLstmBlackBoxTest) {
4500   const int n_batch = 2;
4501   const int n_input = 5;
4502   const int n_cell = 4;
4503   const int n_output = 3;
4504   const float ceil_clip = 0.0;
4505   const float proj_clip = 0.0;
4506 
4507   LSTMOpModel layer_norm_lstm(
4508       n_batch, n_input, n_cell, n_output,
4509       /*use_cifg=*/false, /*use_peephole=*/true,
4510       /*use_projection_weights=*/true,
4511       /*use_projection_bias=*/false, ceil_clip, proj_clip,
4512       {
4513           {n_batch, n_input},  // input tensor
4514 
4515           {n_cell, n_input},  // input_to_input_weight tensor
4516           {n_cell, n_input},  // input_to_forget_weight tensor
4517           {n_cell, n_input},  // input_to_cell_weight tensor
4518           {n_cell, n_input},  // input_to_output_weight tensor
4519 
4520           {n_cell, n_output},  // recurrent_to_input_weight tensor
4521           {n_cell, n_output},  // recurrent_to_forget_weight tensor
4522           {n_cell, n_output},  // recurrent_to_cell_weight tensor
4523           {n_cell, n_output},  // recurrent_to_output_weight tensor
4524 
4525           {n_cell},  // cell_to_input_weight tensor
4526           {n_cell},  // cell_to_forget_weight tensor
4527           {n_cell},  // cell_to_output_weight tensor
4528 
4529           {n_cell},  // input_gate_bias tensor
4530           {n_cell},  // forget_gate_bias tensor
4531           {n_cell},  // cell_bias tensor
4532           {n_cell},  // output_gate_bias tensor
4533 
4534           {n_output, n_cell},  // projection_weight tensor
4535           {0},                 // projection_bias tensor
4536 
4537           {n_batch, n_output},  // activation_state tensor
4538           {n_batch, n_cell},    // cell_state tensor
4539 
4540           {n_cell},  // input_layer_norm_coefficient tensor
4541           {n_cell},  // forget_layer_norm_coefficient tensor
4542           {n_cell},  // cell_layer_norm_coefficient tensor
4543           {n_cell},  // output_layer_norm_coefficient tensor
4544       },
4545       /*weight_type=*/TensorType_FLOAT32);
4546 
4547   layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
4548   layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
4549   layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
4550   layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
4551 
4552   layer_norm_lstm.SetInputGateBias(input_gate_bias_);
4553   layer_norm_lstm.SetCellBias(cell_gate_bias_);
4554   layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
4555   layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
4556 
4557   layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
4558   layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
4559   layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
4560   layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
4561 
4562   layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
4563   layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
4564   layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
4565 
4566   layer_norm_lstm.SetInputLayerNormCoefficients(input_layer_norm_coefficients_);
4567   layer_norm_lstm.SetForgetLayerNormCoefficients(
4568       forget_layer_norm_coefficients_);
4569   layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
4570   layer_norm_lstm.SetOutputLayerNormCoefficients(
4571       output_layer_norm_coefficients_);
4572 
4573   layer_norm_lstm.SetProjectionWeights(projection_weights_);
4574 
4575   // Verify the final output.
4576   const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
4577       {
4578           // Batch0: 3 (input_sequence_size) * 3 (n_output)
4579           0.0244077, 0.128027, -0.00170918,  // seq 0
4580           0.0137642, 0.140751, 0.0395835,    // seq 1
4581           -0.00459231, 0.155278, 0.0837377,  // seq 2
4582       },
4583       {
4584           // Batch1: 3 (input_sequence_size) * 3 (n_output)
4585           -0.00692428, 0.0848741, 0.063445,  // seq 0
4586           -0.00403912, 0.139963, 0.072681,   // seq 1
4587           0.00752706, 0.161903, 0.0561371,   // seq 2
4588       }};
4589 
4590   VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm);
4591 }
4592 
class CifgPeepholeProjectionNoClippingLayerNormLstmTest : public BaseLstmTest {
  // Populates fixture data for a CIFG (coupled input/forget gate) LSTM with
  // peephole connections, projection, and layer normalization. Input-gate
  // tensors are intentionally absent: with CIFG the input gate is derived
  // from the forget gate.
  void SetUp() override {
    // Input weights: n_cell (4) x n_input (5) each.
    input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
                                -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
                                -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
                              -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
                              -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};
    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
                                -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
                                -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};

    // Gate biases: n_cell (4) each.
    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};

    // Recurrent weights: n_cell (4) x n_output (3) each.
    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
                                    0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};
    recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};

    // Peephole (cell-to-gate) weights: n_cell (4) each.
    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};

    // Layer-normalization coefficients: n_cell (4) each.
    forget_layer_norm_coefficients_ = {0.2, 0.2, 0.4, 0.3};
    cell_layer_norm_coefficients_ = {0.7, 0.2, 0.3, 0.8};
    output_layer_norm_coefficients_ = {0.6, 0.2, 0.2, 0.5};
    // Projection: n_output (3) x n_cell (4).
    projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
                           0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};

    // Two batches, three time steps of five inputs each.
    lstm_input_ = {
        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };
  }
};
4638 
TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,LayerNormLstmBlackBoxTest)4639 TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
4640        LayerNormLstmBlackBoxTest) {
4641   const int n_batch = 2;
4642   const int n_input = 5;
4643   const int n_cell = 4;
4644   const int n_output = 3;
4645   const float ceil_clip = 0.0;
4646   const float proj_clip = 0.0;
4647 
4648   LSTMOpModel layer_norm_lstm(
4649       n_batch, n_input, n_cell, n_output,
4650       /*use_cifg=*/true, /*use_peephole=*/true,
4651       /*use_projection_weights=*/true,
4652       /*use_projection_bias=*/false, ceil_clip, proj_clip,
4653       {
4654           {n_batch, n_input},  // input tensor
4655 
4656           {0, 0},             // input_to_input_weight tensor
4657           {n_cell, n_input},  // input_to_forget_weight tensor
4658           {n_cell, n_input},  // input_to_cell_weight tensor
4659           {n_cell, n_input},  // input_to_output_weight tensor
4660 
4661           {0, 0},              // recurrent_to_input_weight tensor
4662           {n_cell, n_output},  // recurrent_to_forget_weight tensor
4663           {n_cell, n_output},  // recurrent_to_cell_weight tensor
4664           {n_cell, n_output},  // recurrent_to_output_weight tensor
4665 
4666           {0},       // cell_to_input_weight tensor
4667           {n_cell},  // cell_to_forget_weight tensor
4668           {n_cell},  // cell_to_output_weight tensor
4669 
4670           {0},       // input_gate_bias tensor
4671           {n_cell},  // forget_gate_bias tensor
4672           {n_cell},  // cell_bias tensor
4673           {n_cell},  // output_gate_bias tensor
4674 
4675           {n_output, n_cell},  // projection_weight tensor
4676           {0},                 // projection_bias tensor
4677 
4678           {n_batch, n_output},  // activation_state tensor
4679           {n_batch, n_cell},    // cell_state tensor
4680 
4681           {0},       // input_layer_norm_coefficient tensor
4682           {n_cell},  // forget_layer_norm_coefficient tensor
4683           {n_cell},  // cell_layer_norm_coefficient tensor
4684           {n_cell},  // output_layer_norm_coefficient tensor
4685       },
4686       /*weight_type=*/TensorType_FLOAT32);
4687 
4688   layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
4689   layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
4690   layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
4691 
4692   layer_norm_lstm.SetCellBias(cell_gate_bias_);
4693   layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
4694   layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
4695 
4696   layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
4697   layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
4698   layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
4699 
4700   layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
4701   layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
4702 
4703   layer_norm_lstm.SetForgetLayerNormCoefficients(
4704       forget_layer_norm_coefficients_);
4705   layer_norm_lstm.SetCellLayerNormCoefficients(cell_layer_norm_coefficients_);
4706   layer_norm_lstm.SetOutputLayerNormCoefficients(
4707       output_layer_norm_coefficients_);
4708 
4709   layer_norm_lstm.SetProjectionWeights(projection_weights_);
4710 
4711   // Verify the final output.
4712   const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
4713       {
4714           // Batch0: 3 (input_sequence_size) * 3 (n_output)
4715           0.02129706, 0.140816242, 0.0112733059,     // seq 0
4716           0.0132302344, 0.152308047, 0.0346313119,   // seq 1
4717           -0.0123688057, 0.165790111, 0.0893077999,  // seq 2
4718       },
4719       {
4720           // Batch1: 3 (input_sequence_size) * 3 (n_output)
4721           -0.0226350538, 0.0916948169, 0.0769175813,  // seq 0
4722           -0.0269966982, 0.149707705, 0.094149217,    // seq 1
4723           -0.0103429332, 0.173016444, 0.0720508844,   // seq 2
4724       }};
4725 
4726   VerifyGoldens(lstm_input_, layer_norm_lstm_golden_output, &layer_norm_lstm);
4727 }
4728 
// Shared scaffolding for reduction-op models (e.g. MEAN): holds the
// input/axis/output tensor indices and common populate/extract helpers.
class BaseReduceOpModel : public SingleOpModelWithNNAPI {
 public:
  // Writes the reduction axes into the axis tensor (only meaningful when
  // the subclass added axis_ as a regular, non-const input).
  void SetAxis(const std::vector<int>& data) { PopulateTensor(axis_, data); }

  // Writes the values to be reduced into the input tensor.
  template <class T>
  void SetInput(const std::vector<T>& data) {
    PopulateTensor(input_, data);
  }

  // Reads the reduced result back as a typed vector.
  template <class T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

  // Reads a uint8 output and converts it to float using the output tensor's
  // quantization scale and zero point.
  std::vector<float> GetDequantizedOutput() {
    return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
                               GetScale(output_), GetZeroPoint(output_));
  }

  // Shape of the output tensor after reduction.
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

  // Index of the input tensor within the interpreter.
  int Input() { return input_; }

 protected:
  int input_;   // input tensor index
  int axis_;    // axis tensor index
  int output_;  // output tensor index
};
4757 
// Model for the test case where the axis is a dynamic (run-time) tensor.
class MeanOpDynamicModel : public BaseReduceOpModel {
 public:
  // `axis` is added as a regular input, so the reduction axes can be fed at
  // run time via SetAxis(). Note the AddInput/AddOutput call order fixes the
  // tensor indices, so it must match the op's expected signature.
  MeanOpDynamicModel(const TensorData& input, const TensorData& output,
                     const TensorData& axis, bool keep_dims) {
    input_ = AddInput(input);
    axis_ = AddInput(axis);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 CreateReducerOptions(builder_, keep_dims).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }
};
4771 
// MEAN over run-time axes {1, 0, -3, -3} (for a rank-3 input, -3 aliases
// axis 0, so the set collapses to {0, 1}) without keeping dimensions.
TEST(DynamicFloatMeanOpTest, NotKeepDims) {
  // Input holds the values 1..24 in a 4x3x2 layout.
  std::vector<float> input_values;
  input_values.reserve(24);
  for (int i = 1; i <= 24; ++i) input_values.push_back(static_cast<float>(i));

  MeanOpDynamicModel model({TensorType_FLOAT32, {4, 3, 2}},
                           {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
                           /*keep_dims=*/false);
  const std::vector<int> reduction_axes = {1, 0, -3, -3};
  model.SetAxis(reduction_axes);
  model.SetInput(input_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({12, 13})));
}
4786 
// Model for the test case where the axis is a const (compile-time) tensor.
class MeanOpConstModel : public BaseReduceOpModel {
 public:
  // `axis` is baked in as a constant tensor, so the reduction axes cannot
  // change at run time (SetAxis must not be used with this model).
  MeanOpConstModel(const TensorData& input, const TensorData& output,
                   std::initializer_list<int> axis_shape,
                   std::initializer_list<int> axis, bool keep_dims) {
    input_ = AddInput(input);
    axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
    output_ = AddOutput(output);
    SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
                 CreateReducerOptions(builder_, keep_dims).Union());
    BuildInterpreterWithNNAPI({GetShape(input_)});
  }
};
4801 
4802 // Tests for reduce_mean
// reduce_mean over const axes {1, 0, -3, -3} (dedupes to {0, 1} for a
// rank-3 input) with keep_dims off: a 4x3x2 input reduces to shape {2}.
TEST(NNAPIDelegate, MeanFloatNotKeepDims) {
  // Input holds the values 1..24.
  std::vector<float> input_values;
  input_values.reserve(24);
  for (int i = 1; i <= 24; ++i) input_values.push_back(static_cast<float>(i));

  MeanOpConstModel model({TensorType_FLOAT32, {4, 3, 2}},
                         {TensorType_FLOAT32, {2}}, {4}, {1, 0, -3, -3},
                         /*keep_dims=*/false);
  model.SetInput(input_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetOutput<float>(), NnapiArrayFloatNear({12, 13}));
}
4814 
// reduce_mean over const axes {0, 2} with keep_dims on: a 4x3x2 input
// keeps rank 3 and reduces to shape {1, 3, 1}.
TEST(NNAPIDelegate, MeanFloatKeepDims) {
  // Input holds the values 1..24.
  std::vector<float> input_values;
  input_values.reserve(24);
  for (int i = 1; i <= 24; ++i) input_values.push_back(static_cast<float>(i));

  MeanOpConstModel model({TensorType_FLOAT32, {4, 3, 2}},
                         {TensorType_FLOAT32, {3}}, {2}, {0, 2},
                         /*keep_dims=*/true);
  model.SetInput(input_values);
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 3, 1}));
  EXPECT_THAT(model.GetOutput<float>(),
              NnapiArrayFloatNear({10.5, 12.5, 14.5}));
}
4826 
// Scaffolding for EMBEDDING_LOOKUP models: int32 indices select rows of a
// weight tensor; the result is always extracted as float.
class BaseEmbeddingLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
                             std::initializer_list<int> weight_shape,
                             TensorType weight_type = TensorType_FLOAT32) {
    input_ = AddInput(TensorType_INT32);
    weight_ = AddInput(weight_type);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({index_shape, weight_shape});
  }

  // Row indices to look up.
  void SetInput(std::initializer_list<int> data) {
    PopulateTensor(input_, data);
  }

  // Looked-up rows, flattened.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

 protected:
  int input_;   // index tensor index
  int weight_;  // weight (embedding table) tensor index
  int output_;  // output tensor index
};
4850 
4851 class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
4852  public:
4853   using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
4854 
Set3DWeightMatrix(const std::function<float (int,int,int)> & function)4855   void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
4856     TfLiteTensor* tensor = interpreter_->tensor(weight_);
4857     int rows = tensor->dims->data[0];
4858     int columns = tensor->dims->data[1];
4859     int features = tensor->dims->data[2];
4860     for (int i = 0; i < rows; i++) {
4861       for (int j = 0; j < columns; j++) {
4862         for (int k = 0; k < features; k++) {
4863           tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
4864         }
4865       }
4866     }
4867   }
4868 };
4869 
// Looks up rows {1, 0, 2} of a 3x2x4 table whose entries encode their own
// coordinates (row + column/10 + feature/100), so the output is self-checking.
TEST(NNAPIDelegate, EmbeddingLookupSimpleTest) {
  EmbeddingLookupOpModel model({3}, {3, 2, 4});
  model.Set3DWeightMatrix([](int row, int column, int feature) {
    return row + column / 10.0f + feature / 100.0f;
  });
  model.SetInput({1, 0, 2});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({
                  1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
                  0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
                  2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
              }));
}
4885 
// Model wrapping HASHTABLE_LOOKUP: each `lookup` id is matched against the
// `key` tensor; the corresponding row of `value` is copied to `output`, and
// `hit` records per-lookup found (1) / not-found (0).
class HashtableLookupOpModel : public SingleOpModelWithNNAPI {
 public:
  HashtableLookupOpModel(std::initializer_list<int> lookup_shape,
                         std::initializer_list<int> key_shape,
                         std::initializer_list<int> value_shape,
                         TensorType type) {
    lookup_ = AddInput(TensorType_INT32);
    key_ = AddInput(TensorType_INT32);
    value_ = AddInput(type);
    output_ = AddOutput(type);
    hit_ = AddOutput(TensorType_UINT8);
    SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({lookup_shape, key_shape, value_shape});
  }

  // Ids to be looked up.
  void SetLookup(std::initializer_list<int> data) {
    PopulateTensor<int>(lookup_, data);
  }

  // Table keys. NOTE(review): the tests in this file pass keys sorted
  // ascending — presumably a kernel requirement; confirm before relying on
  // unsorted keys.
  void SetHashtableKey(std::initializer_list<int> data) {
    PopulateTensor<int>(key_, data);
  }

  // String-valued table rows.
  void SetHashtableValue(const std::vector<string>& content) {
    PopulateStringTensor(value_, content);
  }

  // 1-D float table: value[i] = function(i).
  void SetHashtableValue(const std::function<float(int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(value_);
    int rows = tensor->dims->data[0];
    for (int i = 0; i < rows; i++) {
      tensor->data.f[i] = function(i);
    }
  }

  // 2-D float table: value[i][j] = function(i, j).
  void SetHashtableValue(const std::function<float(int, int)>& function) {
    TfLiteTensor* tensor = interpreter_->tensor(value_);
    int rows = tensor->dims->data[0];
    int features = tensor->dims->data[1];
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < features; j++) {
        tensor->data.f[i * features + j] = function(i, j);
      }
    }
  }

  // Reads back a string-typed output tensor element by element.
  std::vector<string> GetStringOutput() {
    TfLiteTensor* output = interpreter_->tensor(output_);
    int num = GetStringCount(output);
    std::vector<string> result(num);
    for (int i = 0; i < num; i++) {
      auto ref = GetString(output, i);
      result[i] = string(ref.str, ref.len);
    }
    return result;
  }

  // Float output rows, flattened.
  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  // Per-lookup hit flags (1 = key present, 0 = miss).
  std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); }

 private:
  int lookup_;  // lookup-id tensor index
  int key_;     // key tensor index
  int value_;   // value-table tensor index
  int output_;  // output tensor index
  int hit_;     // hit-flag tensor index
};
4953 
// Looks up four ids against a 3-key table of 2-wide float rows; row values
// encode their coordinate (row + column/10), and -292 is a deliberate miss.
TEST(NNAPIDelegate, HashtableLookupTest2DInput) {
  HashtableLookupOpModel model({4}, {3}, {3, 2}, TensorType_FLOAT32);

  model.SetHashtableKey({-11, 0, 1234});
  model.SetHashtableValue([](int row, int column) { return row + column / 10.0f; });
  model.SetLookup({1234, -292, -11, 0});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     2.0, 2.1,  // 2-nd item
                                     0, 0,      // Not found
                                     0.0, 0.1,  // 0-th item
                                     1.0, 1.1,  // 1-st item
                                 }));
  EXPECT_THAT(model.GetHit(), ElementsAreArray({1, 0, 1, 1}));
}
4976 
// Same lookup scenario as the 2-D test but with scalar rows: value[i] is
// i*i/10, and the id -292 is a deliberate miss producing 0.
TEST(NNAPIDelegate, HashtableLookupTest1DInput) {
  HashtableLookupOpModel model({4}, {3}, {3}, TensorType_FLOAT32);

  model.SetHashtableKey({-11, 0, 1234});
  model.SetHashtableValue([](int row) { return row * row / 10.0f; });
  model.SetLookup({1234, -292, -11, 0});

  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0.4,  // 2-nd item
                                     0,    // Not found
                                     0.0,  // 0-th item
                                     0.1,  // 1-st item
                                 }));
  EXPECT_THAT(model.GetHit(), ElementsAreArray({1, 0, 1, 1}));
}
4999 
5000 // A base class of PRelu op model. It provides the constructor for
5001 // FloatPReluOpModel and QuantizedPReluOpModel.
class PReluOpModel : public SingleOpModelWithNNAPI {
 public:
  // The output tensor reuses the input's type, shape, and quantization
  // range, so float and quantized variants share this one constructor.
  PReluOpModel(const TensorData& input, const TensorData& alpha)
      : input_type_(input.type) {
    input_ = AddInput(input);
    alpha_ = AddInput(alpha);
    output_ = AddOutput({input.type, input.shape, input.min, input.max});
    SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
    BuildInterpreterWithNNAPI({GetShape(input_), GetShape(alpha_)});
  }

  // Populates the input; SetData handles the input_type_-specific encoding.
  void SetInput(std::initializer_list<float> data) {
    SetData(input_, input_type_, data);
  }

  // Populates alpha with the same element type as the input.
  void SetAlpha(std::initializer_list<float> data) {
    SetData(alpha_, input_type_, data);
  }

  // Reads the output back as floats via GetData.
  std::vector<float> GetOutput() {
    std::vector<float> output;
    GetData(output_, input_type_, &output);
    return output;
  }

 protected:
  int input_;   // input tensor index
  int alpha_;   // alpha (per-channel slope) tensor index
  int output_;  // output tensor index

  const TensorType input_type_;  // shared element type of input/alpha/output
};
5034 
// Float PRelu: non-negative inputs pass through, negatives are scaled by
// the per-channel alpha {0, 1, 2} broadcast over the 1x2x2x3 input.
TEST(NNAPIDelegate, PReluFloat) {
  PReluOpModel model({TensorType_FLOAT32, {1, 2, 2, 3}},
                     {TensorType_FLOAT32, {1, 1, 3}});

  model.SetAlpha({0.0f, 1.0f, 2.0f});
  model.SetInput({
      0.0f, 0.0f, 0.0f,     // Row 1, Column 1
      1.0f, 1.0f, 1.0f,     // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,  // Row 2, Column 1
      -2.0f, -2.0f, -2.0f,  // Row 2, Column 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({
                                     0.0f, 0.0f, 0.0f,    // Row 1, Column 1
                                     1.0f, 1.0f, 1.0f,    // Row 1, Column 2
                                     0.0f, -1.0f, -2.0f,  // Row 2, Column 1
                                     0.0f, -2.0f, -4.0f,  // Row 2, Column 2
                                 }));
}
5054 
// Quantized (uint8) PRelu over the symmetric-ish range [-1, 127/128];
// results are compared with one-quantization-step tolerance.
TEST(NNAPIDelegate, PReluQuantized) {
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  PReluOpModel model({TensorType_UINT8, {1, 2, 2, 3}, kMin, kMax},
                     {TensorType_UINT8, {1, 1, 3}, kMin, kMax});

  model.SetAlpha({0.0f, 0.5f, -0.5f});
  model.SetInput({
      0.0f, 0.0f, 0.0f,        // Row 1, Column 1
      0.5f, 0.5f, 0.5f,        // Row 1, Column 2
      -1.0f, -1.0f, -1.0f,     // Row 2, Column 1
      -0.25f, -0.25f, -0.25f,  // Row 2, Column 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray(ArrayFloatNear(
                  {
                      0.0f, 0.0f, 0.0f,       // Row 1, Column 1
                      0.5f, 0.5f, 0.5f,       // Row 1, Column 2
                      0.0f, -0.5f, 0.5f,      // Row 2, Column 1
                      0.0f, -0.125f, 0.125f,  // Row 2, Column 2
                  },
                  kQuantizedTolerance)));
}
5077 
// Tests the case where paddings is a const tensor. T1 is the dtype.
template <typename T1>
class PadV2OpConstModel : public PadOpModel<T1> {
 public:
  // Variant with constant_values given as a scalar of type T1, stored as a
  // 1-element const input tensor.
  PadV2OpConstModel(const TensorData& input,
                    std::initializer_list<int> paddings_shape,
                    std::initializer_list<int> paddings, T1 constant_values,
                    const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ =
        this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
    this->constant_values_ =
        this->AddConstInput(GetTensorType<T1>(), {constant_values}, {1});

    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape});
  }

  // Variant with constant_values supplied as a regular (run-time) input
  // tensor; paddings are still const.
  PadV2OpConstModel(const TensorData& input,
                    std::initializer_list<int> paddings_shape,
                    std::initializer_list<int> paddings,
                    const TensorData& constant_values,
                    const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ =
        this->AddConstInput(TensorType_INT32, paddings, paddings_shape);
    this->constant_values_ = this->AddInput(constant_values);

    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape});
  }
};
5116 
5117 // Test case where paddings is a non-const tensor.
template <typename RegularInputOutput>
class PadV2OpDynamicModel : public PadOpModel<RegularInputOutput> {
 public:
  // Variant with run-time paddings and a const scalar constant_values.
  PadV2OpDynamicModel(const TensorData& input,
                      std::initializer_list<int> paddings_shape,
                      RegularInputOutput constant_values,
                      const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ = this->AddInput(TensorType_INT32);
    this->constant_values_ = this->AddConstInput(
        GetTensorType<RegularInputOutput>(), {constant_values}, {1});
    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape, paddings_shape});
  }
  // Variant where both paddings and constant_values are run-time inputs.
  PadV2OpDynamicModel(const TensorData& input,
                      std::initializer_list<int> paddings_shape,
                      const TensorData& constant_values,
                      const TensorData& output) {
    this->input_ = this->AddInput(input);
    this->paddings_ = this->AddInput(TensorType_INT32);
    this->constant_values_ = this->AddInput(constant_values);
    this->output_ = this->AddOutput(output);

    this->SetBuiltinOp(BuiltinOperator_PADV2, BuiltinOptions_PadV2Options,
                       CreatePadV2Options(this->builder_).Union());
    this->BuildInterpreterWithNNAPI({input.shape, paddings_shape});
  }
};
5149 
// Const paddings {{0,0},{1,1},{1,1},{0,0}}: pad one element on each side of
// the two middle dimensions, filling with 0.
TEST(PadV2OpTest, SimpleConstTest) {
  PadV2OpConstModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                                 {0, 0, 1, 1, 1, 1, 0, 0}, 0.0,
                                 {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear(
                  {0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}));
}
5162 
// Same padding pattern as SimpleConstTest but filling with 5.
// NOTE(review): despite the "Uint8" suffix, this test runs entirely in
// FLOAT32 (mirrors the original); the name looks historical — confirm
// against upstream before renaming.
TEST(PadV2OpTest, SimpleConstFloat32ValuedTestUint8) {
  PadV2OpConstModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                                 {0, 0, 1, 1, 1, 1, 0, 0}, 5,
                                 {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear(
                  {5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4, 5, 5, 5, 5, 5}));
}
5174 
// Const paddings {{0,1},{0,0},{0,0},{0,1}}: pad after the batch and last
// dimensions, filling with 5.
TEST(PadV2OpTest, Simple4DConstFloat32ValuedTest) {
  PadV2OpConstModel<float> model({TensorType_FLOAT32, {1, 1, 2, 1}}, {4, 2},
                                 {0, 1, 0, 0, 0, 0, 0, 1}, 5,
                                 {TensorType_FLOAT32});
  model.SetInput({3, 3});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 2}));
  EXPECT_THAT(model.GetOutput(), NnapiArrayFloatNear({3, 5, 3, 5, 5, 5, 5, 5}));
}
5185 
// Run-time paddings {{0,0},{1,1},{1,1},{0,0}} with fill value 0.
TEST(PadV2OpTest, SimpleDynamicTest) {
  PadV2OpDynamicModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                                   0.0, {TensorType_FLOAT32});
  model.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear(
                  {0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}));
}
5196 
// Run-time paddings {{0,0},{1,1},{1,1},{0,0}} with fill value 5.
TEST(PadV2OpTest, SimpleDynamicValuedTest) {
  PadV2OpDynamicModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                                   5, {TensorType_FLOAT32});
  model.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  model.SetInput({1, 2, 3, 4});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear(
                  {5, 5, 5, 5, 5, 1, 2, 5, 5, 3, 4, 5, 5, 5, 5, 5}));
}
5207 
TEST(PadV2OpTest, AdvancedConstTest) {
  // Asymmetric constant paddings {{0, 0}, {0, 2}, {1, 3}, {0, 0}} with pad
  // value 0.
  PadV2OpConstModel<float> model({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                                 {0, 0, 0, 2, 1, 3, 0, 0}, 0,
                                 {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4, 5, 6});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
}
5218 
TEST(PadV2OpTest, AdvancedDynamicTest) {
  // Same asymmetric paddings as AdvancedConstTest, but supplied at runtime.
  PadV2OpDynamicModel<float> model({TensorType_FLOAT32, {1, 2, 3, 1}}, {4, 2},
                                   0, {TensorType_FLOAT32});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
  EXPECT_THAT(model.GetOutput(),
              NnapiArrayFloatNear({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
                                   0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
}
5230 
DequantizedArrayNear(const std::vector<float> & values,const float min,const float max)5231 std::vector<testing::Matcher<float>> DequantizedArrayNear(
5232     const std::vector<float>& values, const float min, const float max) {
5233   const float quantization_tolerance = (max - min) / 255.0;
5234   return ArrayFloatNear(values, quantization_tolerance);
5235 }
5236 
5237 template <typename integer_type, TensorType tensor_dtype>
SimpleConstTestV2()5238 void SimpleConstTestV2() {
5239   // Padding is represented as four 2-D lists representing above padding and
5240   // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
5241   PadV2OpConstModel<integer_type> m(
5242       {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
5243       {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
5244   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
5245   m.template SetQuantizedPadValue<integer_type>(0);
5246   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5247   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5248               ElementsAreArray(DequantizedArrayNear(
5249                   {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0},
5250                   -1.0, 1.0)));
5251   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
5252 }
5253 
// Instantiates SimpleConstTestV2 for both quantized tensor types supported by
// the NNAPI delegate.
TEST(QuantizedPadV2OpTest, UInt8SimpleConstTest) {
  SimpleConstTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleConstTest) {
  SimpleConstTestV2<int8_t, TensorType_INT8>();
}
5260 
5261 template <typename integer_type, TensorType tensor_dtype>
SimpleDynamicTestV2()5262 void SimpleDynamicTestV2() {
5263   PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0},
5264                                       {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
5265                                       {tensor_dtype, {}, -1.0, 1.0});
5266   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
5267   m.template SetQuantizedPadValue<integer_type>(0);
5268   m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
5269   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5270   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5271               ElementsAreArray(DequantizedArrayNear(
5272                   {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0},
5273                   -1.0, 1.0)));
5274   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
5275 }
5276 
// Instantiates SimpleDynamicTestV2 for both supported quantized tensor types.
TEST(QuantizedPadV2OpTest, UInt8SimpleDynamicTest) {
  SimpleDynamicTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleDynamicTest) {
  SimpleDynamicTestV2<int8_t, TensorType_INT8>();
}
5283 
5284 template <typename integer_type, TensorType tensor_dtype>
AdvancedConstTestV2()5285 void AdvancedConstTestV2() {
5286   PadV2OpConstModel<integer_type> m(
5287       {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
5288       {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
5289   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
5290   m.template SetQuantizedPadValue<integer_type>(0);
5291   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5292   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5293               ElementsAreArray(DequantizedArrayNear(
5294                   {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0,
5295                    0, 0,    0,   0,   0, 0, 0, 0, 0,   0,   0,    0, 0, 0},
5296                   -1.0, 1.0)));
5297   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
5298 }
5299 
// Instantiates AdvancedConstTestV2 for both supported quantized tensor types.
TEST(QuantizedPadV2OpTest, UInt8AdvancedConstTest) {
  AdvancedConstTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedConstTest) {
  AdvancedConstTestV2<int8_t, TensorType_INT8>();
}
5306 
5307 template <typename integer_type, TensorType tensor_dtype>
AdvancedDynamicTestV2()5308 void AdvancedDynamicTestV2() {
5309   PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0},
5310                                       {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
5311                                       {tensor_dtype, {}, -1.0, 1.0});
5312   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
5313   m.template SetQuantizedPadValue<integer_type>(0);
5314   m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
5315   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5316   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5317               ElementsAreArray(DequantizedArrayNear(
5318                   {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0,
5319                    0, 0,    0,   0,   0, 0, 0, 0, 0,   0,   0,    0, 0, 0},
5320                   -1.0, 1.0)));
5321   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
5322 }
5323 
// Instantiates AdvancedDynamicTestV2 for both supported quantized tensor
// types.
TEST(QuantizedPadV2OpTest, UInt8AdvancedDynamicTest) {
  AdvancedDynamicTestV2<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedDynamicTest) {
  AdvancedDynamicTestV2<int8_t, TensorType_INT8>();
}
5330 
5331 template <typename integer_type, TensorType tensor_dtype>
SimpleConstValuedTest()5332 void SimpleConstValuedTest() {
5333   // Padding is represented as four 2-D lists representing above padding and
5334   // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
5335   PadV2OpConstModel<integer_type> m(
5336       {tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0},
5337       {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
5338   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
5339   m.template SetQuantizedPadValue<integer_type>(-0.5);
5340   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5341   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5342               ElementsAreArray(DequantizedArrayNear(
5343                   {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9,
5344                    0.7, -0.5, -0.5, -0.5, -0.5, -0.5},
5345                   -1.0, 1.0)));
5346   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
5347 }
5348 
// Instantiates SimpleConstValuedTest for both supported quantized tensor
// types.
TEST(QuantizedPadV2OpTest, UInt8SimpleConstValuedTest) {
  SimpleConstValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleConstValuedTest) {
  SimpleConstValuedTest<int8_t, TensorType_INT8>();
}
5355 
5356 template <typename integer_type, TensorType tensor_dtype>
SimpleDynamicValuedTest()5357 void SimpleDynamicValuedTest() {
5358   PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0},
5359                                       {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
5360                                       {tensor_dtype, {}, -1.0, 1.0});
5361   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7});
5362   m.template SetQuantizedPadValue<integer_type>(-0.5);
5363   m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
5364   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5365   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5366               ElementsAreArray(DequantizedArrayNear(
5367                   {-0.5, -0.5, -0.5, -0.5, -0.5, -0.8, 0.2, -0.5, -0.5, 0.9,
5368                    0.7, -0.5, -0.5, -0.5, -0.5, -0.5},
5369                   -1.0, 1.0)));
5370   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
5371 }
5372 
// Instantiates SimpleDynamicValuedTest for both supported quantized tensor
// types.
TEST(QuantizedPadV2OpTest, UInt8SimpleDynamicValuedTest) {
  SimpleDynamicValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8SimpleDynamicValuedTest) {
  SimpleDynamicValuedTest<int8_t, TensorType_INT8>();
}
5379 
5380 template <typename integer_type, TensorType tensor_dtype>
AdvancedConstValuedTest()5381 void AdvancedConstValuedTest() {
5382   PadV2OpConstModel<integer_type> m(
5383       {tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0},
5384       {tensor_dtype, {1}, -1.0, 1.0}, {tensor_dtype, {}, -1.0, 1.0});
5385   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
5386   m.template SetQuantizedPadValue<integer_type>(-0.5);
5387   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5388   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5389               ElementsAreArray(DequantizedArrayNear(
5390                   {-0.5, -0.8, 0.2,  0.9,  -0.5, -0.5, -0.5, -0.5, 0.7,  0.1,
5391                    -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
5392                    -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
5393                   -1.0, 1.0)));
5394   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
5395 }
5396 
// Instantiates AdvancedConstValuedTest for both supported quantized tensor
// types.
TEST(QuantizedPadV2OpTest, UInt8AdvancedConstValuedTest) {
  AdvancedConstValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedConstValuedTest) {
  AdvancedConstValuedTest<int8_t, TensorType_INT8>();
}
5403 
5404 template <typename integer_type, TensorType tensor_dtype>
AdvancedDynamicValuedTest()5405 void AdvancedDynamicValuedTest() {
5406   PadV2OpDynamicModel<integer_type> m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0},
5407                                       {4, 2}, {tensor_dtype, {1}, -1.0, 1.0},
5408                                       {tensor_dtype, {}, -1.0, 1.0});
5409   m.template SetQuantizedInput<integer_type>({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3});
5410   m.template SetQuantizedPadValue<integer_type>(-0.5);
5411   m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
5412   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5413   EXPECT_THAT(m.template GetDequantizedOutput<integer_type>(),
5414               ElementsAreArray(DequantizedArrayNear(
5415                   {-0.5, -0.8, 0.2,  0.9,  -0.5, -0.5, -0.5, -0.5, 0.7,  0.1,
5416                    -0.3, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5,
5417                    -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5},
5418                   -1.0, 1.0)));
5419   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
5420 }
5421 
// Instantiates AdvancedDynamicValuedTest for both supported quantized tensor
// types.
TEST(QuantizedPadV2OpTest, UInt8AdvancedDynamicValuedTest) {
  AdvancedDynamicValuedTest<uint8_t, TensorType_UINT8>();
}
TEST(QuantizedPadV2OpTest, Int8AdvancedDynamicValuedTest) {
  AdvancedDynamicValuedTest<int8_t, TensorType_INT8>();
}
5428 
5429 // A base class of Leaky ReLU op model. It provides the constructor for
5430 // FloatLeakyReluOpModel and QuantizedLeakyReluOpModel.
5431 class LeakyReluOpModel : public SingleOpModelWithNNAPI {
5432  public:
LeakyReluOpModel(const TensorData & input,const float alpha)5433   LeakyReluOpModel(const TensorData& input, const float alpha)
5434       : input_type_(input.type) {
5435     input_ = AddInput(input);
5436     output_ = AddOutput({input.type, input.shape, input.min, input.max});
5437 
5438     SetBuiltinOp(BuiltinOperator_LEAKY_RELU, BuiltinOptions_LeakyReluOptions,
5439                  CreateLeakyReluOptions(builder_, alpha).Union());
5440     BuildInterpreterWithNNAPI({GetShape(input_)});
5441   }
5442 
SetInput(std::initializer_list<float> data)5443   void SetInput(std::initializer_list<float> data) {
5444     SetData(input_, input_type_, data);
5445   }
5446 
GetOutput()5447   std::vector<float> GetOutput() {
5448     std::vector<float> output;
5449     GetData(output_, input_type_, &output);
5450     return output;
5451   }
5452 
5453  protected:
5454   int input_;
5455   int output_;
5456 
5457   const TensorType input_type_;
5458 };
5459 
TEST(NNAPIDelegate, LeakyReluFloat) {
  // alpha = 0.5, so negative inputs are halved and non-negative inputs pass
  // through unchanged.
  LeakyReluOpModel model({TensorType_FLOAT32, {2, 3}}, 0.5f);

  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({
                                     0.0f, 1.0f, 3.0f,    // Row 1
                                     1.0f, -0.5f, -1.0f,  // Row 2
                                 }));
}
5474 
TEST(NNAPIDelegate, LeakyReluQuantized) {
  // Uint8 input quantized over [-8, 8 * 127/128]; results are compared with
  // the file-wide quantized tolerance.
  const float kMin = -1;
  const float kMax = 127.f / 128.f;
  LeakyReluOpModel model({TensorType_UINT8, {2, 3}, 8 * kMin, 8 * kMax}, 0.5f);
  model.SetInput({
      0.0f, 1.0f, 3.0f,    // Row 1
      1.0f, -1.0f, -2.0f,  // Row 2
  });
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(ArrayFloatNear(
                                     {
                                         0.0f, 1.0f, 3.0f,    // Row 1
                                         1.0f, -0.5f, -1.0f,  // Row 2
                                     },
                                     kQuantizedTolerance)));
}
5491 }  // namespace
5492 
namespace ops {
namespace builtin {
// Forward declaration of the builtin FLOOR kernel registration; the custom-op
// tests below use it as the TFLite-side implementation of "nnapi-custom-op".
TfLiteRegistration* Register_FLOOR();
}  // namespace builtin
}  // namespace ops
5498 
5499 namespace {
5500 
GetNNAPIDimensions(const TfLiteTensor * tensor)5501 std::vector<uint32_t> GetNNAPIDimensions(const TfLiteTensor* tensor) {
5502   std::vector<uint32_t> dimensions;
5503   dimensions.reserve(tensor->dims->size);
5504   if (tensor->dims_signature != nullptr &&
5505       tensor->dims_signature->size == tensor->dims->size) {
5506     for (auto d : TfLiteIntArrayView(tensor->dims_signature)) {
5507       uint32_t nnapi_dim = (d == -1) ? 0 : static_cast<uint32_t>(d);
5508       dimensions.push_back(nnapi_dim);
5509     }
5510   } else {
5511     dimensions.assign(tensor->dims->data,
5512                       tensor->dims->data + tensor->dims->size);
5513   }
5514   return dimensions;
5515 }
5516 
// Name of the test-only custom op. Its TFLite implementation is float32
// floor, and NnapiTestVendorPlugin below maps it to ANEURALNETWORKS_FLOOR.
static const char kTestCustomOp[] = "nnapi-custom-op";
// A minimal vendor plugin that recognizes the kTestCustomOp custom op and
// lowers it to the built-in ANEURALNETWORKS_FLOOR operation via the NNAPI C
// API.
class NnapiTestVendorPlugin : public NnapiDelegateVendorPlugin {
 public:
  NnapiTestVendorPlugin() {
    // Install the static callbacks on the C-style plugin interface inherited
    // from NnapiDelegateVendorPlugin.
    ValidateNode = DoValidateNode;
    MapNode = DoMapNode;
    ConfigureCompilationHints = DoConfigureCompilationHints;
    ConfigureExecutionHints = DoConfigureExecutionHints;
  }

  // Accepts only the test custom op with exactly one float32 input and one
  // float32 output.
  static bool DoValidateNode(const TfLiteContext* context,
                             const TfLiteRegistration* registration,
                             const TfLiteNode* node) {
    if (strcmp(kTestCustomOp, registration->custom_name) != 0) {
      return false;
    }
    if (node->inputs->size != 1 || node->outputs->size != 1) {
      return false;
    }
    if (context->tensors[node->inputs->data[(0)]].type != kTfLiteFloat32 ||
        context->tensors[node->outputs->data[(0)]].type != kTfLiteFloat32) {
      return false;
    }
    return true;
  }

  // Ensures the TFLite tensor `tensor_index` has a float32 NNAPI operand,
  // creating one if it has not been mapped yet, and appends the NNAPI operand
  // index to `indices`.
  static TfLiteStatus AddFloat32Tensor(const TfLiteContext* context,
                                       int tensor_index,
                                       NnapiMappingUtilCInterface* mapping,
                                       std::vector<uint32_t>* indices,
                                       ANeuralNetworksModel* model) {
    int ann_tensor_index = mapping->TfLiteIndexToNnIndex(mapping, tensor_index);
    if (ann_tensor_index != -1) {
      // The tensor was already added to the NNAPI model; reuse its operand.
      indices->push_back(ann_tensor_index);
      return kTfLiteOk;
    }
    // Allocate a new tensor index
    ann_tensor_index = mapping->AddNewNnTensorIndex(mapping, tensor_index);
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    auto dimensions = GetNNAPIDimensions(tensor);
    ANeuralNetworksOperandType operand_type{
        .type = ANEURALNETWORKS_TENSOR_FLOAT32,
        .dimensionCount = static_cast<uint32_t>(dimensions.size()),
        .dimensions = dimensions.data(),
        .scale = 0.0f,
        .zeroPoint = 0,
    };
    EXPECT_EQ(NnApiImplementation()->ANeuralNetworksModel_addOperand(
                  model, &operand_type),
              ANEURALNETWORKS_NO_ERROR);
    if (tensor->allocation_type == kTfLiteMmapRo) {
      // Constant tensor: copy its buffer into the NNAPI model up front.
      EXPECT_EQ(NnApiImplementation()->ANeuralNetworksModel_setOperandValue(
                    model, ann_tensor_index, tensor->data.data, tensor->bytes),
                ANEURALNETWORKS_NO_ERROR);
    }
    indices->push_back(ann_tensor_index);
    return kTfLiteOk;
  }

  // Translates one custom-op node into an ANEURALNETWORKS_FLOOR operation,
  // mapping all of its input and output tensors first.
  static TfLiteStatus DoMapNode(TfLiteContext* context, const TfLiteNode* node,
                                int node_index,
                                NnapiMappingUtilCInterface* mapping,
                                ANeuralNetworksModel* model) {
    std::vector<uint32_t> input_indices;
    std::vector<uint32_t> output_indices;
    for (int input_pos = 0; input_pos < node->inputs->size; ++input_pos) {
      const auto input_index = node->inputs->data[input_pos];
      EXPECT_EQ(AddFloat32Tensor(context, input_index, mapping, &input_indices,
                                 model),
                kTfLiteOk);
    }
    for (int output_pos = 0; output_pos < node->outputs->size; ++output_pos) {
      const auto output_index = node->outputs->data[output_pos];
      EXPECT_EQ(AddFloat32Tensor(context, output_index, mapping,
                                 &output_indices, model),
                kTfLiteOk);
    }
    EXPECT_EQ(
        NnApiImplementation()->ANeuralNetworksModel_addOperation(
            model, ANEURALNETWORKS_FLOOR,
            static_cast<uint32_t>(input_indices.size()), input_indices.data(),
            static_cast<uint32_t>(output_indices.size()),
            output_indices.data()),
        ANEURALNETWORKS_NO_ERROR);
    // Record that this NNAPI operation originated from TFLite node
    // `node_index`.
    mapping->AddNnapiToTfliteOpMapping(mapping, node_index);
    return kTfLiteOk;
  }

  // Compilation hints are accepted but ignored by this test plugin.
  static TfLiteStatus DoConfigureCompilationHints(
      const char* compilation_hints, ANeuralNetworksCompilation* compilation) {
    return kTfLiteOk;
  }

  // Execution hints are accepted but ignored by this test plugin.
  static TfLiteStatus DoConfigureExecutionHints(
      const char* execution_hints, ANeuralNetworksExecution* execution) {
    return kTfLiteOk;
  }
};
5616 
5617 class CustomFloorOpModel : public SingleOpModelWithNNAPI {
5618  public:
CustomFloorOpModel(const StatefulNnApiDelegate::Options & options,const TensorData & input,const TensorData & output,bool allow_fp32_relax_to_fp16=false,bool apply_delegate=true)5619   CustomFloorOpModel(const StatefulNnApiDelegate::Options& options,
5620                      const TensorData& input, const TensorData& output,
5621                      bool allow_fp32_relax_to_fp16 = false,
5622                      bool apply_delegate = true)
5623       : SingleOpModelWithNNAPI(options) {
5624     Init(input, output, allow_fp32_relax_to_fp16, apply_delegate);
5625   }
5626 
input()5627   int input() { return input_; }
output()5628   int output() { return output_; }
5629 
GetOutput()5630   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
5631 
5632  protected:
5633   int input_;
5634   int output_;
5635 
5636  private:
5637   // Performs initialization logic shared across all constructors.
Init(const TensorData & input,const TensorData & output,bool allow_fp32_relax_to_fp16=false,bool apply_delegate=true)5638   void Init(const TensorData& input, const TensorData& output,
5639             bool allow_fp32_relax_to_fp16 = false, bool apply_delegate = true) {
5640     input_ = AddInput(input);
5641     output_ = AddOutput(output);
5642     SetCustomOp(kTestCustomOp, {}, tflite::ops::builtin::Register_FLOOR);
5643     BuildInterpreterWithNNAPI({GetShape(input_)}, allow_fp32_relax_to_fp16,
5644                               apply_delegate);
5645   }
5646 };
5647 
TEST(NNAPIDelegate, CustomFloorVendorExtension) {
  // Route the custom op through the test vendor plugin on the reference
  // accelerator.
  auto plugin = std::make_unique<NnapiTestVendorPlugin>();
  StatefulNnApiDelegate::Options options;
  options.accelerator_name = "nnapi-reference";
  options.vendor_plugin = plugin.get();

  CustomFloorOpModel model(options, {TensorType_FLOAT32, {1, 2, 2, 1}},
                           {TensorType_FLOAT32, {1, 2, 2, 1}});
  model.PopulateTensor<float>(model.input(), {0, 0.2, 1.7, 2.8});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0.0, 0.0, 1.0, 2.0}));
}
5660 
TEST(NNAPIDelegate,CustomFloorVendorExtensionDynamic)5661 TEST(NNAPIDelegate, CustomFloorVendorExtensionDynamic) {
5662   // Skip the test until b/243704946 is fixed.
5663   GTEST_SKIP();
5664   // Models with dynamic dimensions and vendor plugin is not supported before
5665   // NNAPI 1.2 (API level 29).
5666   if (NnApiImplementation()->android_sdk_version <
5667       delegate::nnapi::kMinSdkVersionForNNAPI12) {
5668     GTEST_SKIP();
5669   }
5670 
5671   auto vendor_plugin = std::make_unique<NnapiTestVendorPlugin>();
5672   StatefulNnApiDelegate::Options options;
5673   options.accelerator_name = "nnapi-reference";
5674   options.vendor_plugin = vendor_plugin.get();
5675   options.allow_dynamic_dimensions = true;
5676 
5677   // Both input and output tensors have dynamic batch.
5678   auto tensor_data = TensorData{TensorType_FLOAT32,
5679                                 /*shape=*/{1, 2, 2, 1},
5680                                 /*min=*/0.0f,
5681                                 /*max=*/0.0f,
5682                                 /*scale=*/0.0f,
5683                                 /*zero_point=*/0,
5684                                 /*per_channel_quantization=*/false,
5685                                 /*per_channel_quantization_scales=*/{},
5686                                 /*per_channel_quantization_offsets=*/{},
5687                                 /*channel_index=*/0,
5688                                 /*traversal_order=*/{},
5689                                 /*format=*/{},
5690                                 /*block_size=*/{},
5691                                 /*block_map=*/{},
5692                                 /*shape_signature=*/{-1, 2, 2, 1}};
5693   size_t max_batch_size = 2;
5694   size_t tensor_max_size = max_batch_size * 2 * 2 * 1 * sizeof(float);
5695   CustomFloorOpModel m(options, tensor_data, tensor_data,
5696                        /*allow_fp32_relax_to_fp16=*/false,
5697                        /*apply_delegate=*/false);
5698   m.SetTensorMaxSize(m.input(), tensor_max_size);
5699   m.SetTensorMaxSize(m.output(), tensor_max_size);
5700   m.ApplyNNAPIDelegate();
5701 
5702   // Try the max batch size.
5703   EXPECT_EQ(m.ResizeInputTensor(m.input(), {2, 2, 2, 1}), kTfLiteOk);
5704   EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
5705   m.PopulateTensor<float>(m.input(), {0, 0.2, 1.7, 2.8, 3.4, 4.1, 5.9, 6.3});
5706   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5707   EXPECT_THAT(m.GetOutput(),
5708               ElementsAreArray({0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0}));
5709 
5710   // Try another batch size.
5711   EXPECT_EQ(m.ResizeInputTensor(m.input(), {1, 2, 2, 1}), kTfLiteOk);
5712   EXPECT_EQ(m.AllocateTensors(), kTfLiteOk);
5713   m.PopulateTensor<float>(m.input(), {1.7, 2.8, 3.4, 4.1});
5714   ASSERT_EQ(m.Invoke(), kTfLiteOk);
5715   EXPECT_THAT(m.GetOutput(), ElementsAreArray({1.0, 2.0, 3.0, 4.0}));
5716 }
5717 
5718 }  // namespace
5719 }  // namespace tflite
5720