xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/shim/tflite_tensor_view_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/shim/tflite_tensor_view.h"
16 
17 #include <cstdint>
18 #include <string>
19 #include <utility>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/interpreter.h"
24 #include "tensorflow/lite/kernels/shim/test_util.h"
25 #include "tensorflow/lite/string_util.h"
26 
27 namespace tflite {
28 namespace shim {
29 namespace {
30 
31 using ::testing::Eq;
32 
// Writes a 3x2 bool pattern through a TensorView that was move-constructed out
// of the StatusOr returned by TensorView::New, then checks that the wrapped
// TfLiteTensor reflects those writes.
TEST(TfLiteTensorW, Bool) {
  ::tflite::Interpreter interpreter;
  // Fail loudly on setup errors instead of silently ignoring the status.
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  auto* tflite_tensor = interpreter.tensor(0);
  ReallocDynamicTensor<bool>({3, 2}, tflite_tensor);
  tflite_tensor->name = "test_bool";
  auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);

  // Test move construction: `t` is initialized from an rvalue, never assigned.
  auto t_premove_or = TensorView::New(tflite_tensor);
  ASSERT_TRUE(t_premove_or.ok()) << t_premove_or.status();
  auto t = std::move(t_premove_or.value());

  auto data = t.Data<bool>();
  // Among flat indices 0..5 only 0 and 5 are divisible by 5, so only the first
  // and last elements become true.
  for (int32_t i = 0; i < 3 * 2; ++i) data[i] = (i % 5 == 0);

  ASSERT_THAT(TfliteTensorDebugString(tflite_tensor),
              Eq("[[1, 0], [0, 0], [0, 1]]"));
}
53 
54 template <typename IntType>
IntTest()55 void IntTest() {
56   ::tflite::Interpreter interpreter;
57   interpreter.AddTensors(1);
58   interpreter.AllocateTensors();
59   auto* tflite_tensor = interpreter.tensor(0);
60   ReallocDynamicTensor<IntType>({3, 2}, tflite_tensor);
61   tflite_tensor->name = "test_int";
62   auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);
63 
64   // Test move assignment
65   auto t_premove_or = TensorView::New(tflite_tensor);
66   ASSERT_TRUE(t_premove_or.ok()) << t_premove_or.status();
67   auto t = std::move(t_premove_or.value());
68 
69   auto data = t.Data<IntType>();
70   for (int32_t i = 0; i < 3 * 2; ++i) data[i] = i;
71 
72   ASSERT_THAT(TfliteTensorDebugString(tflite_tensor),
73               Eq("[[0, 1], [2, 3], [4, 5]]"));
74 }
75 
// One test case per supported integer width; each runs the shared IntTest
// round-trip check for that element type.
TEST(TfLiteTensorW, Int8) {
  IntTest<int8_t>();
}
TEST(TfLiteTensorW, UInt8) {
  IntTest<uint8_t>();
}
TEST(TfLiteTensorW, Int16) {
  IntTest<int16_t>();
}
TEST(TfLiteTensorW, Int32) {
  IntTest<int32_t>();
}
TEST(TfLiteTensorW, Int64) {
  IntTest<int64_t>();
}
81 
82 template <typename FloatType>
FloatTest()83 void FloatTest() {
84   ::tflite::Interpreter interpreter;
85   interpreter.AddTensors(1);
86   interpreter.AllocateTensors();
87   auto* tflite_tensor = interpreter.tensor(0);
88   ReallocDynamicTensor<FloatType>({3, 2}, tflite_tensor);
89   tflite_tensor->name = "test_float";
90   auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);
91 
92   auto t_or = TensorView::New(tflite_tensor);
93   ASSERT_TRUE(t_or.ok()) << t_or.status();
94   auto& t = t_or.value();
95 
96   auto data = t.Data<FloatType>();
97   for (int32_t i = 0; i < 3 * 2; ++i) data[i] = static_cast<FloatType>(i) / 2.;
98 
99   ASSERT_THAT(TfliteTensorDebugString(tflite_tensor),
100               Eq("[[0, 0.5], [1, 1.5], [2, 2.5]]"));
101 }
102 
// Run the shared floating-point round-trip check in single and double
// precision.
TEST(TfLiteTensorW, Float) {
  FloatTest<float>();
}
TEST(TfLiteTensorW, Double) {
  FloatTest<double>();
}
105 
// String tensors:
//  1. Write through both the flat Data() span and a 2-D As<> view.
//  2. Read back through a fresh mutable view.
//  3. Read back again through a view over a pointer-to-const tensor.
//  4. Check the debug string.
TEST(TfLiteTensorW, Str) {
  ::tflite::Interpreter interpreter;
  // Fail loudly on setup errors instead of silently ignoring the status.
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  auto* tflite_tensor = interpreter.tensor(0);
  ReallocDynamicTensor<std::string>({3, 2}, tflite_tensor);
  tflite_tensor->name = "test_str";
  auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);

  // Scoped so the writing view is destroyed before the tensor is re-read below;
  // the reads that follow expect its writes to be visible in the tensor.
  {
    auto t_or = TensorView::New(tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
    auto& t = t_or.value();
    auto t_mat = t.As<::tensorflow::tstring, 2>();
    // Mix flat and 2-D writes: in a 3x2 row-major layout, (1, 0) is flat
    // index 2 and (2, 1) is flat index 5.
    t.Data<::tensorflow::tstring>()[0] = "a";
    t.Data<::tensorflow::tstring>()[1] = "bc";
    t_mat(1, 0) = "def";
    t.Data<::tensorflow::tstring>()[3] = "g";
    t.Data<::tensorflow::tstring>()[4] = "";
    t_mat(2, 1) = "hi";
  }

  {
    auto t_or = TensorView::New(tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
    auto& t = t_or.value();
    EXPECT_THAT(t.Data<::tensorflow::tstring>(),
                ::testing::ElementsAre("a", "bc", "def", "g", "", "hi"));
  }

  // Read through a pointer-to-const so TensorView::New is handed a
  // `const TfLiteTensor*`. Writing `const auto p = tflite_tensor;` would only
  // make the pointer itself const (`TfLiteTensor* const`) and still pass a
  // mutable tensor.
  const ::TfLiteTensor* const_tflite_tensor = tflite_tensor;
  {
    const auto t_or = TensorView::New(const_tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
    const auto& t = t_or.value();
    EXPECT_THAT(t.Data<::tensorflow::tstring>(),
                ::testing::ElementsAre("a", "bc", "def", "g", "", "hi"));
  }

  EXPECT_THAT(TfliteTensorDebugString(tflite_tensor),
              Eq("[[a, bc], [def, g], [, hi]]"));
}
148 
// Wrapping a zero-length string tensor in a TensorView and then destroying the
// view must leave the tensor reporting zero strings.
TEST(TfLiteTensorW, EmptyStr) {
  ::tflite::Interpreter interpreter;
  // Fail loudly on setup errors instead of silently ignoring the status.
  ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
  ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
  auto* tflite_tensor = interpreter.tensor(0);
  ReallocDynamicTensor<std::string>(/*shape=*/{0}, tflite_tensor);
  tflite_tensor->name = "test_str";
  auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);

  // Placing tensor_view instance in a block to ensure its dtor runs before the
  // string count is inspected.
  {
    auto t_or = TensorView::New(tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
  }

  EXPECT_THAT(GetStringCount(tflite_tensor), Eq(0));
}
166 
167 }  // namespace
168 }  // namespace shim
169 }  // namespace tflite
170