1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/shim/tflite_tensor_view.h"
16
17 #include <cstdint>
18 #include <string>
19 #include <utility>
20
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "tensorflow/lite/interpreter.h"
24 #include "tensorflow/lite/kernels/shim/test_util.h"
25 #include "tensorflow/lite/string_util.h"
26
27 namespace tflite {
28 namespace shim {
29 namespace {
30
31 using ::testing::Eq;
32
// Wraps a 3x2 bool dynamic tensor in a TensorView and writes a pattern
// through it. Checks the underlying TfLiteTensor sees the writes, and that a
// view moved out of its StatusOr still aliases the tensor's data.
TEST(TfLiteTensorW, Bool) {
  ::tflite::Interpreter interpreter;
  interpreter.AddTensors(1);
  interpreter.AllocateTensors();
  auto* tflite_tensor = interpreter.tensor(0);
  ReallocDynamicTensor<bool>({3, 2}, tflite_tensor);
  tflite_tensor->name = "test_bool";
  auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);

  // Test move construction: `t` is constructed from an rvalue reference to
  // the view held inside the StatusOr (this is not move assignment).
  auto t_premove_or = TensorView::New(tflite_tensor);
  ASSERT_TRUE(t_premove_or.ok()) << t_premove_or.status();
  auto t = std::move(t_premove_or.value());

  // True at flat indices 0 and 5, false elsewhere.
  auto data = t.Data<bool>();
  for (int32_t i = 0; i < 3 * 2; ++i) data[i] = (i % 5 == 0);

  ASSERT_THAT(TfliteTensorDebugString(tflite_tensor),
              Eq("[[1, 0], [0, 0], [0, 1]]"));
}
53
54 template <typename IntType>
IntTest()55 void IntTest() {
56 ::tflite::Interpreter interpreter;
57 interpreter.AddTensors(1);
58 interpreter.AllocateTensors();
59 auto* tflite_tensor = interpreter.tensor(0);
60 ReallocDynamicTensor<IntType>({3, 2}, tflite_tensor);
61 tflite_tensor->name = "test_int";
62 auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);
63
64 // Test move assignment
65 auto t_premove_or = TensorView::New(tflite_tensor);
66 ASSERT_TRUE(t_premove_or.ok()) << t_premove_or.status();
67 auto t = std::move(t_premove_or.value());
68
69 auto data = t.Data<IntType>();
70 for (int32_t i = 0; i < 3 * 2; ++i) data[i] = i;
71
72 ASSERT_THAT(TfliteTensorDebugString(tflite_tensor),
73 Eq("[[0, 1], [2, 3], [4, 5]]"));
74 }
75
// Runs the shared integer round-trip check once per integer element type.
TEST(TfLiteTensorW, Int8) {
  IntTest<int8_t>();
}
TEST(TfLiteTensorW, UInt8) {
  IntTest<uint8_t>();
}
TEST(TfLiteTensorW, Int16) {
  IntTest<int16_t>();
}
TEST(TfLiteTensorW, Int32) {
  IntTest<int32_t>();
}
TEST(TfLiteTensorW, Int64) {
  IntTest<int64_t>();
}
81
82 template <typename FloatType>
FloatTest()83 void FloatTest() {
84 ::tflite::Interpreter interpreter;
85 interpreter.AddTensors(1);
86 interpreter.AllocateTensors();
87 auto* tflite_tensor = interpreter.tensor(0);
88 ReallocDynamicTensor<FloatType>({3, 2}, tflite_tensor);
89 tflite_tensor->name = "test_float";
90 auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);
91
92 auto t_or = TensorView::New(tflite_tensor);
93 ASSERT_TRUE(t_or.ok()) << t_or.status();
94 auto& t = t_or.value();
95
96 auto data = t.Data<FloatType>();
97 for (int32_t i = 0; i < 3 * 2; ++i) data[i] = static_cast<FloatType>(i) / 2.;
98
99 ASSERT_THAT(TfliteTensorDebugString(tflite_tensor),
100 Eq("[[0, 0.5], [1, 1.5], [2, 2.5]]"));
101 }
102
// Runs the shared floating-point round-trip check for float and double.
TEST(TfLiteTensorW, Float) {
  FloatTest<float>();
}
TEST(TfLiteTensorW, Double) {
  FloatTest<double>();
}
105
// Round-trips string data through TensorView over a 3x2 string tensor:
//   1. Writes through one mutable view, using both Data() and a rank-2 As()
//      accessor.
//   2. Reads the strings back through a fresh mutable view.
//   3. Reads them back through a view created from a pointer-to-const tensor.
//   4. Checks the tensor's debug string.
TEST(TfLiteTensorW, Str) {
  ::tflite::Interpreter interpreter;
  interpreter.AddTensors(1);
  interpreter.AllocateTensors();
  auto* tflite_tensor = interpreter.tensor(0);
  ReallocDynamicTensor<std::string>({3, 2}, tflite_tensor);
  tflite_tensor->name = "test_str";
  auto owned_tflite_tensor = UniqueTfLiteTensor(tflite_tensor);

  // Write through a mutable view. The view is scoped so its destructor has
  // run before the tensor is read back below.
  {
    auto t_or = TensorView::New(tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
    auto& t = t_or.value();
    // The expectations below assume (1, 0) maps to flat index 2 and (2, 1)
    // maps to flat index 5.
    auto t_mat = t.As<::tensorflow::tstring, 2>();
    t.Data<::tensorflow::tstring>()[0] = "a";
    t.Data<::tensorflow::tstring>()[1] = "bc";
    t_mat(1, 0) = "def";
    t.Data<::tensorflow::tstring>()[3] = "g";
    t.Data<::tensorflow::tstring>()[4] = "";
    t_mat(2, 1) = "hi";
  }

  // Read back through a fresh mutable view.
  {
    auto t_or = TensorView::New(tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
    auto& t = t_or.value();
    EXPECT_THAT(t.Data<::tensorflow::tstring>(),
                ::testing::ElementsAre("a", "bc", "def", "g", "", "hi"));
  }

  // Read back through a const tensor. `const auto` on a pointer only makes
  // the pointer itself const (TfLiteTensor* const), so New() would still be
  // called with a mutable TfLiteTensor*. `const auto*` yields
  // `const TfLiteTensor*`, which is the case this block is meant to cover.
  const auto* const_tflite_tensor = tflite_tensor;
  {
    const auto t_or = TensorView::New(const_tflite_tensor);
    ASSERT_TRUE(t_or.ok()) << t_or.status();
    const auto& t = t_or.value();
    EXPECT_THAT(t.Data<::tensorflow::tstring>(),
                ::testing::ElementsAre("a", "bc", "def", "g", "", "hi"));
  }

  EXPECT_THAT(TfliteTensorDebugString(tflite_tensor),
              Eq("[[a, bc], [def, g], [, hi]]"));
}
148
// Creates and destroys a TensorView over a string tensor of shape {0}, then
// checks that the tensor holds zero strings afterwards.
TEST(TfLiteTensorW, EmptyStr) {
  ::tflite::Interpreter interp;
  interp.AddTensors(1);
  interp.AllocateTensors();
  auto* const str_tensor = interp.tensor(0);
  ReallocDynamicTensor<std::string>(/*shape=*/{0}, str_tensor);
  str_tensor->name = "test_str";
  auto tensor_owner = UniqueTfLiteTensor(str_tensor);

  // Confine the view to this scope so its destructor has already run by the
  // time the string count is inspected below.
  {
    auto view_or = TensorView::New(str_tensor);
    ASSERT_TRUE(view_or.ok()) << view_or.status();
  }

  EXPECT_THAT(GetStringCount(str_tensor), Eq(0));
}
166
167 } // namespace
168 } // namespace shim
169 } // namespace tflite
170