1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "utils/lua-utils.h"

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/lua_utils_tests_generated.h"
#include "utils/strings/stringpiece.h"
#include "utils/test-data-test-utils.h"
#include "utils/testing/test_data_generator.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
30
31 namespace libtextclassifier3 {
32 namespace {
33
34 using testing::DoubleEq;
35 using testing::ElementsAre;
36 using testing::Eq;
37 using testing::FloatEq;
38
39 class LuaUtilsTest : public testing::Test, protected LuaEnvironment {
40 protected:
LuaUtilsTest()41 LuaUtilsTest()
42 : schema_(GetTestFileContent("utils/lua_utils_tests.bfbs")),
43 flatbuffer_builder_(schema_.get()) {
44 EXPECT_THAT(RunProtected([this] {
45 LoadDefaultLibraries();
46 return LUA_OK;
47 }),
48 Eq(LUA_OK));
49 }
50
RunScript(StringPiece script)51 void RunScript(StringPiece script) {
52 EXPECT_THAT(luaL_loadbuffer(state_, script.data(), script.size(),
53 /*name=*/nullptr),
54 Eq(LUA_OK));
55 EXPECT_THAT(
56 lua_pcall(state_, /*nargs=*/0, /*num_results=*/1, /*errfunc=*/0),
57 Eq(LUA_OK));
58 }
59
60 OwnedFlatbuffer<reflection::Schema, std::string> schema_;
61 MutableFlatbufferBuilder flatbuffer_builder_;
62 TestDataGenerator test_data_generator_;
63 };
64
65 template <typename T>
66 class TypedLuaUtilsTest : public LuaUtilsTest {};
67
68 using testing::Types;
69 using LuaTypes =
70 ::testing::Types<int64, uint64, int32, uint32, int16, uint16, int8, uint8,
71 float, double, bool, std::string>;
72 TYPED_TEST_SUITE(TypedLuaUtilsTest, LuaTypes);
73
// Pushes a vector of generated values onto the Lua stack and checks that
// ReadVector returns the same elements.
TYPED_TEST(TypedLuaUtilsTest, HandlesVectors) {
  std::vector<TypeParam> elements(5);
  // Fill the whole range instead of repeating the element count, so the size
  // lives in one place and can never overrun the vector.
  std::generate(elements.begin(), elements.end(), [this]() {
    return this->test_data_generator_.template generate<TypeParam>();
  });

  this->PushVector(elements);

  EXPECT_THAT(this->template ReadVector<TypeParam>(),
              testing::ContainerEq(elements));
}
85
// Pushes generated values via PushVectorIterator and checks that ReadVector
// returns the same elements.
TYPED_TEST(TypedLuaUtilsTest, HandlesVectorIterators) {
  std::vector<TypeParam> elements(5);
  // Fill the whole range instead of repeating the element count, so the size
  // lives in one place and can never overrun the vector.
  std::generate(elements.begin(), elements.end(), [this]() {
    return this->test_data_generator_.template generate<TypeParam>();
  });

  // PushVectorIterator takes a pointer; `elements` outlives the read below
  // (NOTE(review): presumably the values are not copied — confirm in
  // lua-utils.h).
  this->PushVectorIterator(&elements);

  EXPECT_THAT(this->template ReadVector<TypeParam>(),
              testing::ContainerEq(elements));
}
97
// Indexes each repeated field of a pushed flatbuffer from Lua and checks that
// the first element of every field is returned correctly.
TEST_F(LuaUtilsTest, IndexCallback) {
  test::TestDataT input_data;
  input_data.repeated_byte_field = {1, 2};
  input_data.repeated_ubyte_field = {1, 2};
  input_data.repeated_int_field = {1, 2};
  input_data.repeated_uint_field = {1, 2};
  input_data.repeated_long_field = {1, 2};
  input_data.repeated_ulong_field = {1, 2};
  input_data.repeated_bool_field = {true, false};
  input_data.repeated_float_field = {1, 2};
  input_data.repeated_double_field = {1, 2};
  input_data.repeated_string_field = {"1", "2"};

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(test::TestData::Pack(builder, &input_data));
  const flatbuffers::DetachedBuffer input_buffer = builder.Release();
  PushFlatbuffer(schema_.get(),
                 flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
  lua_setglobal(state_, "arg");
  // A Lua script that reads the first element (Lua indices start at 1) of
  // each repeated field and returns it as the matching scalar field.
  // Indexing the vectors should trigger their __index callback.
  RunScript(R"lua(
    return {
      byte_field = arg.repeated_byte_field[1],
      ubyte_field = arg.repeated_ubyte_field[1],
      int_field = arg.repeated_int_field[1],
      uint_field = arg.repeated_uint_field[1],
      long_field = arg.repeated_long_field[1],
      ulong_field = arg.repeated_ulong_field[1],
      bool_field = arg.repeated_bool_field[1],
      float_field = arg.repeated_float_field[1],
      double_field = arg.repeated_double_field[1],
      string_field = arg.repeated_string_field[1],
    }
  )lua");

  // Read the table returned by the script back into a flatbuffer.
  std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
  ReadFlatbuffer(/*index=*/-1, buffer.get());
  const std::string serialized_buffer = buffer->Serialize();
  std::unique_ptr<test::TestDataT> test_data =
      LoadAndVerifyMutableFlatbuffer<test::TestData>(serialized_buffer);
  // Stop before dereferencing if the buffer failed verification.
  ASSERT_THAT(test_data, testing::NotNull());

  EXPECT_THAT(test_data->byte_field, 1);
  EXPECT_THAT(test_data->ubyte_field, 1);
  EXPECT_THAT(test_data->int_field, 1);
  EXPECT_THAT(test_data->uint_field, 1);
  EXPECT_THAT(test_data->long_field, 1);
  EXPECT_THAT(test_data->ulong_field, 1);
  EXPECT_THAT(test_data->bool_field, true);
  EXPECT_THAT(test_data->float_field, FloatEq(1));
  EXPECT_THAT(test_data->double_field, DoubleEq(1));
  EXPECT_THAT(test_data->string_field, "1");
}
152
// Iterates each repeated field of a pushed flatbuffer with pairs() from Lua
// and checks that exactly the expected key/value pairs are produced.
TEST_F(LuaUtilsTest, PairCallback) {
  test::TestDataT input_data;
  input_data.repeated_byte_field = {1, 2};
  input_data.repeated_ubyte_field = {1, 2};
  input_data.repeated_int_field = {1, 2};
  input_data.repeated_uint_field = {1, 2};
  input_data.repeated_long_field = {1, 2};
  input_data.repeated_ulong_field = {1, 2};
  input_data.repeated_bool_field = {true, false};
  input_data.repeated_float_field = {1, 2};
  input_data.repeated_double_field = {1, 2};
  input_data.repeated_string_field = {"1", "2"};

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(test::TestData::Pack(builder, &input_data));
  const flatbuffers::DetachedBuffer input_buffer = builder.Release();
  PushFlatbuffer(schema_.get(),
                 flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
  lua_setglobal(state_, "arg");

  // Iterate the pushed repeated fields with pairs() and check the values are
  // correct. This should trigger the __pairs callback. The entry count is
  // compared too, so an iteration that yields nothing cannot pass vacuously.
  RunScript(R"lua(
    local function count_entries(t)
      local count = 0
      for _ in pairs(t) do
        count = count + 1
      end
      return count
    end

    local function equal(actual, expected)
      local count = 0
      for key, value in pairs(actual) do
        if value ~= expected[key] then
          return false
        end
        count = count + 1
      end
      return count == count_entries(expected)
    end

    local valid = equal(arg.repeated_byte_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_ubyte_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_int_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_uint_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_long_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_ulong_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_bool_field, {[1]=true,[2]=false})
    valid = valid and equal(arg.repeated_float_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_double_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_string_field, {[1]="1",[2]="2"})

    return {
      bool_field = valid
    }
  )lua");

  // Read the table returned by the script back into a flatbuffer.
  std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
  ReadFlatbuffer(/*index=*/-1, buffer.get());
  const std::string serialized_buffer = buffer->Serialize();
  std::unique_ptr<test::TestDataT> test_data =
      LoadAndVerifyMutableFlatbuffer<test::TestData>(serialized_buffer);
  // Stop before dereferencing if the buffer failed verification.
  ASSERT_THAT(test_data, testing::NotNull());

  EXPECT_THAT(test_data->bool_field, true);
}
210
// Pushes a fully populated flatbuffer to Lua and copies its fields into a new
// table from a script. It then reads that table back into a flatbuffer and
// checks that scalar, string, nested and repeated fields all survive.
TEST_F(LuaUtilsTest, PushAndReadsFlatbufferRoundTrip) {
  // Setup.
  test::TestDataT input_data;
  input_data.byte_field = 1;
  input_data.ubyte_field = 2;
  input_data.int_field = 10;
  input_data.uint_field = 11;
  input_data.long_field = 20;
  input_data.ulong_field = 21;
  input_data.bool_field = true;
  input_data.float_field = 42.1;
  input_data.double_field = 12.4;
  input_data.string_field = "hello there";
  // Nested field.
  input_data.nested_field = std::make_unique<test::TestDataT>();
  input_data.nested_field->float_field = 64;
  input_data.nested_field->string_field = "hello nested";
  // Repeated fields.
  input_data.repeated_byte_field = {1, 2, 1};
  input_data.repeated_ubyte_field = {2, 4, 2};
  input_data.repeated_int_field = {1, 2, 3};
  input_data.repeated_uint_field = {2, 4, 6};
  input_data.repeated_long_field = {4, 5, 6};
  input_data.repeated_ulong_field = {8, 10, 12};
  input_data.repeated_bool_field = {true, false, true};
  input_data.repeated_float_field = {1.23, 2.34, 3.45};
  input_data.repeated_double_field = {1.11, 2.22, 3.33};
  input_data.repeated_string_field = {"a", "bold", "one"};
  // Repeated nested fields.
  input_data.repeated_nested_field.push_back(
      std::make_unique<test::TestDataT>());
  input_data.repeated_nested_field.back()->string_field = "a";
  input_data.repeated_nested_field.push_back(
      std::make_unique<test::TestDataT>());
  input_data.repeated_nested_field.back()->string_field = "b";
  input_data.repeated_nested_field.push_back(
      std::make_unique<test::TestDataT>());
  input_data.repeated_nested_field.back()->repeated_string_field = {"nested",
                                                                    "nested2"};
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(test::TestData::Pack(builder, &input_data));
  const flatbuffers::DetachedBuffer input_buffer = builder.Release();
  PushFlatbuffer(schema_.get(),
                 flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
  lua_setglobal(state_, "arg");

  RunScript(R"lua(
    return {
      byte_field = arg.byte_field,
      ubyte_field = arg.ubyte_field,
      int_field = arg.int_field,
      uint_field = arg.uint_field,
      long_field = arg.long_field,
      ulong_field = arg.ulong_field,
      bool_field = arg.bool_field,
      float_field = arg.float_field,
      double_field = arg.double_field,
      string_field = arg.string_field,
      nested_field = {
        float_field = arg.nested_field.float_field,
        string_field = arg.nested_field.string_field,
      },
      repeated_byte_field = arg.repeated_byte_field,
      repeated_ubyte_field = arg.repeated_ubyte_field,
      repeated_int_field = arg.repeated_int_field,
      repeated_uint_field = arg.repeated_uint_field,
      repeated_long_field = arg.repeated_long_field,
      repeated_ulong_field = arg.repeated_ulong_field,
      repeated_bool_field = arg.repeated_bool_field,
      repeated_float_field = arg.repeated_float_field,
      repeated_double_field = arg.repeated_double_field,
      repeated_string_field = arg.repeated_string_field,
      repeated_nested_field = {
        { string_field = arg.repeated_nested_field[1].string_field },
        { string_field = arg.repeated_nested_field[2].string_field },
        { repeated_string_field = arg.repeated_nested_field[3].repeated_string_field },
      },
    }
  )lua");

  // Read the table returned by the script back into a flatbuffer.
  std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
  ReadFlatbuffer(/*index=*/-1, buffer.get());
  const std::string serialized_buffer = buffer->Serialize();
  std::unique_ptr<test::TestDataT> test_data =
      LoadAndVerifyMutableFlatbuffer<test::TestData>(serialized_buffer);
  // Stop before dereferencing if the buffer failed verification.
  ASSERT_THAT(test_data, testing::NotNull());

  EXPECT_THAT(test_data->byte_field, 1);
  EXPECT_THAT(test_data->ubyte_field, 2);
  EXPECT_THAT(test_data->int_field, 10);
  EXPECT_THAT(test_data->uint_field, 11);
  EXPECT_THAT(test_data->long_field, 20);
  EXPECT_THAT(test_data->ulong_field, 21);
  EXPECT_THAT(test_data->bool_field, true);
  EXPECT_THAT(test_data->float_field, FloatEq(42.1));
  EXPECT_THAT(test_data->double_field, DoubleEq(12.4));
  EXPECT_THAT(test_data->string_field, "hello there");
  EXPECT_THAT(test_data->repeated_byte_field, ElementsAre(1, 2, 1));
  EXPECT_THAT(test_data->repeated_ubyte_field, ElementsAre(2, 4, 2));
  EXPECT_THAT(test_data->repeated_int_field, ElementsAre(1, 2, 3));
  EXPECT_THAT(test_data->repeated_uint_field, ElementsAre(2, 4, 6));
  EXPECT_THAT(test_data->repeated_long_field, ElementsAre(4, 5, 6));
  EXPECT_THAT(test_data->repeated_ulong_field, ElementsAre(8, 10, 12));
  EXPECT_THAT(test_data->repeated_bool_field, ElementsAre(true, false, true));
  // Use ULP-based comparisons for floating point, as the scalar checks do.
  EXPECT_THAT(test_data->repeated_float_field,
              ElementsAre(FloatEq(1.23f), FloatEq(2.34f), FloatEq(3.45f)));
  EXPECT_THAT(test_data->repeated_double_field,
              ElementsAre(DoubleEq(1.11), DoubleEq(2.22), DoubleEq(3.33)));
  EXPECT_THAT(test_data->repeated_string_field,
              ElementsAre("a", "bold", "one"));
  // Nested fields.
  ASSERT_THAT(test_data->nested_field, testing::NotNull());
  EXPECT_THAT(test_data->nested_field->float_field, FloatEq(64));
  EXPECT_THAT(test_data->nested_field->string_field, "hello nested");
  // Repeated nested fields: check the size before indexing into the vector.
  ASSERT_THAT(test_data->repeated_nested_field, testing::SizeIs(3));
  EXPECT_THAT(test_data->repeated_nested_field[0]->string_field, "a");
  EXPECT_THAT(test_data->repeated_nested_field[1]->string_field, "b");
  EXPECT_THAT(test_data->repeated_nested_field[2]->repeated_string_field,
              ElementsAre("nested", "nested2"));
}
329
// Builds a flatbuffer whose repeated table field has four entries. The third
// entry also carries a repeated string field. The test walks both levels with
// pairs() from Lua and checks every string is visited in order.
TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) {
  std::unique_ptr<MutableFlatbuffer> root = flatbuffer_builder_.NewRoot();
  RepeatedField* entries = root->Repeated("repeated_nested_field");
  entries->Add()->Set("string_field", "hello");
  entries->Add()->Set("string_field", "my");

  // Third entry: a string plus a nested list of strings.
  MutableFlatbuffer* third_entry = entries->Add();
  third_entry->Set("string_field", "old");
  RepeatedField* third_entry_strings =
      third_entry->Repeated("repeated_string_field");
  third_entry_strings->Add("friend");
  third_entry_strings->Add("how");
  third_entry_strings->Add("are");

  entries->Add()->Set("string_field", "you?");

  // PushFlatbuffer receives a pointer into this buffer, so it is kept alive
  // for the rest of the test.
  const std::string serialized_buffer = root->Serialize();
  const flatbuffers::Table* root_table =
      flatbuffers::GetRoot<flatbuffers::Table>(serialized_buffer.data());
  PushFlatbuffer(schema_.get(), root_table);
  lua_setglobal(state_, "arg");

  RunScript(R"lua(
    result = {}
    for _, nested in pairs(arg.repeated_nested_field) do
      result[#result + 1] = nested.string_field
      for _, nested_string in pairs(nested.repeated_string_field) do
        result[#result + 1] = nested_string
      end
    end
    return result
  )lua");

  const auto visited_strings = ReadVector<std::string>();
  EXPECT_THAT(visited_strings, ElementsAre("hello", "my", "old", "friend",
                                           "how", "are", "you?"));
}
363
// Binds two independent flatbuffers to two Lua globals and checks that a
// single script can read from both of them.
TEST_F(LuaUtilsTest, CorrectlyReadsTwoFlatbuffersSimultaneously) {
  // Pushes the flatbuffer stored in `bytes` and binds it to the Lua global
  // `global_name`. The caller keeps `bytes` alive while Lua may read it.
  auto bind_flatbuffer = [this](const std::string& bytes,
                                const char* global_name) {
    PushFlatbuffer(schema_.get(),
                   flatbuffers::GetRoot<flatbuffers::Table>(bytes.data()));
    lua_setglobal(state_, global_name);
  };

  std::unique_ptr<MutableFlatbuffer> first = flatbuffer_builder_.NewRoot();
  first->Set("string_field", "first");
  const std::string first_bytes = first->Serialize();

  std::unique_ptr<MutableFlatbuffer> second = flatbuffer_builder_.NewRoot();
  second->Set("string_field", "second");
  const std::string second_bytes = second->Serialize();

  bind_flatbuffer(first_bytes, "arg");
  bind_flatbuffer(second_bytes, "arg2");

  RunScript(R"lua(
    return {arg.string_field, arg2.string_field}
  )lua");

  const auto read_back = ReadVector<std::string>();
  EXPECT_THAT(read_back, ElementsAre("first", "second"));
}
386
387 } // namespace
388 } // namespace libtextclassifier3
389