xref: /aosp_15_r20/external/libtextclassifier/native/utils/lua-utils_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/lua-utils.h"
18 
19 #include <memory>
20 #include <string>
21 
22 #include "utils/flatbuffers/flatbuffers.h"
23 #include "utils/flatbuffers/mutable.h"
24 #include "utils/lua_utils_tests_generated.h"
25 #include "utils/strings/stringpiece.h"
26 #include "utils/test-data-test-utils.h"
27 #include "utils/testing/test_data_generator.h"
28 #include "gmock/gmock.h"
29 #include "gtest/gtest.h"
30 
31 namespace libtextclassifier3 {
32 namespace {
33 
34 using testing::DoubleEq;
35 using testing::ElementsAre;
36 using testing::Eq;
37 using testing::FloatEq;
38 
39 class LuaUtilsTest : public testing::Test, protected LuaEnvironment {
40  protected:
LuaUtilsTest()41   LuaUtilsTest()
42       : schema_(GetTestFileContent("utils/lua_utils_tests.bfbs")),
43         flatbuffer_builder_(schema_.get()) {
44     EXPECT_THAT(RunProtected([this] {
45                   LoadDefaultLibraries();
46                   return LUA_OK;
47                 }),
48                 Eq(LUA_OK));
49   }
50 
RunScript(StringPiece script)51   void RunScript(StringPiece script) {
52     EXPECT_THAT(luaL_loadbuffer(state_, script.data(), script.size(),
53                                 /*name=*/nullptr),
54                 Eq(LUA_OK));
55     EXPECT_THAT(
56         lua_pcall(state_, /*nargs=*/0, /*num_results=*/1, /*errfunc=*/0),
57         Eq(LUA_OK));
58   }
59 
60   OwnedFlatbuffer<reflection::Schema, std::string> schema_;
61   MutableFlatbufferBuilder flatbuffer_builder_;
62   TestDataGenerator test_data_generator_;
63 };
64 
65 template <typename T>
66 class TypedLuaUtilsTest : public LuaUtilsTest {};
67 
68 using testing::Types;
69 using LuaTypes =
70     ::testing::Types<int64, uint64, int32, uint32, int16, uint16, int8, uint8,
71                      float, double, bool, std::string>;
72 TYPED_TEST_SUITE(TypedLuaUtilsTest, LuaTypes);
73 
// Pushes a vector of values produced by the test data generator onto the Lua
// stack and checks that reading it back yields an equal vector.
TYPED_TEST(TypedLuaUtilsTest, HandlesVectors) {
  std::vector<TypeParam> values(5);
  std::generate(values.begin(), values.end(), [this]() {
    return this->test_data_generator_.template generate<TypeParam>();
  });

  this->PushVector(values);

  EXPECT_THAT(this->template ReadVector<TypeParam>(),
              testing::ContainerEq(values));
}
85 
// Same as HandlesVectors, but exposes the vector to Lua through
// PushVectorIterator (which takes a pointer, so `elements` must outlive the
// read below).
TYPED_TEST(TypedLuaUtilsTest, HandlesVectorIterators) {
  std::vector<TypeParam> elements(5);
  // `auto&&` also binds to the proxy references of std::vector<bool>.
  for (auto&& element : elements) {
    element = this->test_data_generator_.template generate<TypeParam>();
  }

  this->PushVectorIterator(&elements);

  EXPECT_THAT(this->template ReadVector<TypeParam>(),
              testing::ContainerEq(elements));
}
97 
// Checks that indexing repeated fields of a pushed flatbuffer from Lua
// (e.g. `arg.repeated_int_field[1]`) returns the first element of each field.
TEST_F(LuaUtilsTest, IndexCallback) {
  test::TestDataT input_data;
  input_data.repeated_byte_field = {1, 2};
  input_data.repeated_ubyte_field = {1, 2};
  input_data.repeated_int_field = {1, 2};
  input_data.repeated_uint_field = {1, 2};
  input_data.repeated_long_field = {1, 2};
  input_data.repeated_ulong_field = {1, 2};
  input_data.repeated_bool_field = {true, false};
  input_data.repeated_float_field = {1, 2};
  input_data.repeated_double_field = {1, 2};
  input_data.repeated_string_field = {"1", "2"};

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(test::TestData::Pack(builder, &input_data));
  const flatbuffers::DetachedBuffer input_buffer = builder.Release();
  PushFlatbuffer(schema_.get(),
                 flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
  lua_setglobal(state_, "arg");
  // A Lua script that reads the vectors and returns the first value of each.
  // This should trigger the __index callback.
  RunScript(R"lua(
    return {
        byte_field = arg.repeated_byte_field[1],
        ubyte_field = arg.repeated_ubyte_field[1],
        int_field = arg.repeated_int_field[1],
        uint_field = arg.repeated_uint_field[1],
        long_field = arg.repeated_long_field[1],
        ulong_field = arg.repeated_ulong_field[1],
        bool_field = arg.repeated_bool_field[1],
        float_field = arg.repeated_float_field[1],
        double_field = arg.repeated_double_field[1],
        string_field = arg.repeated_string_field[1],
    }
  )lua");

  // Read the table returned by the script back into a flatbuffer.
  std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
  ReadFlatbuffer(/*index=*/-1, buffer.get());
  std::unique_ptr<test::TestDataT> test_data =
      LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
  // Fail cleanly instead of dereferencing null if verification failed.
  ASSERT_THAT(test_data, testing::NotNull());

  EXPECT_THAT(test_data->byte_field, 1);
  EXPECT_THAT(test_data->ubyte_field, 1);
  EXPECT_THAT(test_data->int_field, 1);
  EXPECT_THAT(test_data->uint_field, 1);
  EXPECT_THAT(test_data->long_field, 1);
  EXPECT_THAT(test_data->ulong_field, 1);
  EXPECT_THAT(test_data->bool_field, true);
  EXPECT_THAT(test_data->float_field, FloatEq(1));
  EXPECT_THAT(test_data->double_field, DoubleEq(1));
  EXPECT_THAT(test_data->string_field, "1");
}
152 
// Checks that iterating repeated fields of a pushed flatbuffer with Lua's
// `pairs` yields exactly the expected key/value pairs.
TEST_F(LuaUtilsTest, PairCallback) {
  test::TestDataT input_data;
  input_data.repeated_byte_field = {1, 2};
  input_data.repeated_ubyte_field = {1, 2};
  input_data.repeated_int_field = {1, 2};
  input_data.repeated_uint_field = {1, 2};
  input_data.repeated_long_field = {1, 2};
  input_data.repeated_ulong_field = {1, 2};
  input_data.repeated_bool_field = {true, false};
  input_data.repeated_float_field = {1, 2};
  input_data.repeated_double_field = {1, 2};
  input_data.repeated_string_field = {"1", "2"};

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(test::TestData::Pack(builder, &input_data));
  const flatbuffers::DetachedBuffer input_buffer = builder.Release();
  PushFlatbuffer(schema_.get(),
                 flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
  lua_setglobal(state_, "arg");

  // Iterate the pushed repeated fields with `pairs` and check that every
  // visited value is correct. This should trigger the __pairs callback.
  // `equal` also counts the visited entries, so an iterator that yields too
  // few elements (e.g. none at all) cannot pass vacuously.
  RunScript(R"lua(
    local function equal(table1, table2)
      local count = 0
      for key, value in pairs(table1) do
        if value ~= table2[key] then
          return false
        end
        count = count + 1
      end
      return count == #table2
    end

    local valid = equal(arg.repeated_byte_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_ubyte_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_int_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_uint_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_long_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_ulong_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_bool_field, {[1]=true,[2]=false})
    valid = valid and equal(arg.repeated_float_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_double_field, {[1]=1,[2]=2})
    valid = valid and equal(arg.repeated_string_field, {[1]="1",[2]="2"})

    return {
        bool_field = valid
    }
  )lua");

  // Read the table returned by the script back into a flatbuffer.
  std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
  ReadFlatbuffer(/*index=*/-1, buffer.get());
  std::unique_ptr<test::TestDataT> test_data =
      LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
  // Fail cleanly instead of dereferencing null if verification failed.
  ASSERT_THAT(test_data, testing::NotNull());

  EXPECT_THAT(test_data->bool_field, true);
}
210 
TEST_F(LuaUtilsTest,PushAndReadsFlatbufferRoundTrip)211 TEST_F(LuaUtilsTest, PushAndReadsFlatbufferRoundTrip) {
212   // Setup.
213   test::TestDataT input_data;
214   input_data.byte_field = 1;
215   input_data.ubyte_field = 2;
216   input_data.int_field = 10;
217   input_data.uint_field = 11;
218   input_data.long_field = 20;
219   input_data.ulong_field = 21;
220   input_data.bool_field = true;
221   input_data.float_field = 42.1;
222   input_data.double_field = 12.4;
223   input_data.string_field = "hello there";
224   // Nested field.
225   input_data.nested_field = std::make_unique<test::TestDataT>();
226   input_data.nested_field->float_field = 64;
227   input_data.nested_field->string_field = "hello nested";
228   // Repeated fields.
229   input_data.repeated_byte_field = {1, 2, 1};
230   input_data.repeated_byte_field = {1, 2, 1};
231   input_data.repeated_ubyte_field = {2, 4, 2};
232   input_data.repeated_int_field = {1, 2, 3};
233   input_data.repeated_uint_field = {2, 4, 6};
234   input_data.repeated_long_field = {4, 5, 6};
235   input_data.repeated_ulong_field = {8, 10, 12};
236   input_data.repeated_bool_field = {true, false, true};
237   input_data.repeated_float_field = {1.23, 2.34, 3.45};
238   input_data.repeated_double_field = {1.11, 2.22, 3.33};
239   input_data.repeated_string_field = {"a", "bold", "one"};
240   // Repeated nested fields.
241   input_data.repeated_nested_field.push_back(
242       std::make_unique<test::TestDataT>());
243   input_data.repeated_nested_field.back()->string_field = "a";
244   input_data.repeated_nested_field.push_back(
245       std::make_unique<test::TestDataT>());
246   input_data.repeated_nested_field.back()->string_field = "b";
247   input_data.repeated_nested_field.push_back(
248       std::make_unique<test::TestDataT>());
249   input_data.repeated_nested_field.back()->repeated_string_field = {"nested",
250                                                                     "nested2"};
251   flatbuffers::FlatBufferBuilder builder;
252   builder.Finish(test::TestData::Pack(builder, &input_data));
253   const flatbuffers::DetachedBuffer input_buffer = builder.Release();
254   PushFlatbuffer(schema_.get(),
255                  flatbuffers::GetRoot<flatbuffers::Table>(input_buffer.data()));
256   lua_setglobal(state_, "arg");
257 
258   RunScript(R"lua(
259     return {
260         byte_field = arg.byte_field,
261         ubyte_field = arg.ubyte_field,
262         int_field = arg.int_field,
263         uint_field = arg.uint_field,
264         long_field = arg.long_field,
265         ulong_field = arg.ulong_field,
266         bool_field = arg.bool_field,
267         float_field = arg.float_field,
268         double_field = arg.double_field,
269         string_field = arg.string_field,
270         nested_field = {
271           float_field = arg.nested_field.float_field,
272           string_field = arg.nested_field.string_field,
273         },
274         repeated_byte_field = arg.repeated_byte_field,
275         repeated_ubyte_field = arg.repeated_ubyte_field,
276         repeated_int_field = arg.repeated_int_field,
277         repeated_uint_field = arg.repeated_uint_field,
278         repeated_long_field = arg.repeated_long_field,
279         repeated_ulong_field = arg.repeated_ulong_field,
280         repeated_bool_field = arg.repeated_bool_field,
281         repeated_float_field = arg.repeated_float_field,
282         repeated_double_field = arg.repeated_double_field,
283         repeated_string_field = arg.repeated_string_field,
284         repeated_nested_field = {
285           { string_field = arg.repeated_nested_field[1].string_field },
286           { string_field = arg.repeated_nested_field[2].string_field },
287           { repeated_string_field = arg.repeated_nested_field[3].repeated_string_field },
288         },
289     }
290   )lua");
291 
292   // Read the flatbuffer.
293   std::unique_ptr<MutableFlatbuffer> buffer = flatbuffer_builder_.NewRoot();
294   ReadFlatbuffer(/*index=*/-1, buffer.get());
295   const std::string serialized_buffer = buffer->Serialize();
296   std::unique_ptr<test::TestDataT> test_data =
297       LoadAndVerifyMutableFlatbuffer<test::TestData>(buffer->Serialize());
298 
299   EXPECT_THAT(test_data->byte_field, 1);
300   EXPECT_THAT(test_data->ubyte_field, 2);
301   EXPECT_THAT(test_data->int_field, 10);
302   EXPECT_THAT(test_data->uint_field, 11);
303   EXPECT_THAT(test_data->long_field, 20);
304   EXPECT_THAT(test_data->ulong_field, 21);
305   EXPECT_THAT(test_data->bool_field, true);
306   EXPECT_THAT(test_data->float_field, FloatEq(42.1));
307   EXPECT_THAT(test_data->double_field, DoubleEq(12.4));
308   EXPECT_THAT(test_data->string_field, "hello there");
309   EXPECT_THAT(test_data->repeated_byte_field, ElementsAre(1, 2, 1));
310   EXPECT_THAT(test_data->repeated_ubyte_field, ElementsAre(2, 4, 2));
311   EXPECT_THAT(test_data->repeated_int_field, ElementsAre(1, 2, 3));
312   EXPECT_THAT(test_data->repeated_uint_field, ElementsAre(2, 4, 6));
313   EXPECT_THAT(test_data->repeated_long_field, ElementsAre(4, 5, 6));
314   EXPECT_THAT(test_data->repeated_ulong_field, ElementsAre(8, 10, 12));
315   EXPECT_THAT(test_data->repeated_bool_field, ElementsAre(true, false, true));
316   EXPECT_THAT(test_data->repeated_float_field, ElementsAre(1.23, 2.34, 3.45));
317   EXPECT_THAT(test_data->repeated_double_field, ElementsAre(1.11, 2.22, 3.33));
318   EXPECT_THAT(test_data->repeated_string_field,
319               ElementsAre("a", "bold", "one"));
320   // Nested fields.
321   EXPECT_THAT(test_data->nested_field->float_field, FloatEq(64));
322   EXPECT_THAT(test_data->nested_field->string_field, "hello nested");
323   // Repeated nested fields.
324   EXPECT_THAT(test_data->repeated_nested_field[0]->string_field, "a");
325   EXPECT_THAT(test_data->repeated_nested_field[1]->string_field, "b");
326   EXPECT_THAT(test_data->repeated_nested_field[2]->repeated_string_field,
327               ElementsAre("nested", "nested2"));
328 }
329 
// Builds a buffer whose `repeated_nested_field` has four entries, the third of
// which carries its own `repeated_string_field`. A script walks everything with
// `pairs`; the test checks that all strings are visited in order.
TEST_F(LuaUtilsTest, HandlesRepeatedNestedFlatbufferFields) {
  // Create test flatbuffer.
  std::unique_ptr<MutableFlatbuffer> root = flatbuffer_builder_.NewRoot();
  RepeatedField* entries = root->Repeated("repeated_nested_field");
  entries->Add()->Set("string_field", "hello");
  entries->Add()->Set("string_field", "my");

  // The third entry also has a repeated string field of its own.
  MutableFlatbuffer* third_entry = entries->Add();
  third_entry->Set("string_field", "old");
  RepeatedField* third_entry_strings =
      third_entry->Repeated("repeated_string_field");
  third_entry_strings->Add("friend");
  third_entry_strings->Add("how");
  third_entry_strings->Add("are");

  entries->Add()->Set("string_field", "you?");

  // `serialized` backs the table handed to Lua and stays alive until the end
  // of the test.
  const std::string serialized = root->Serialize();
  const flatbuffers::Table* const root_table =
      flatbuffers::GetRoot<flatbuffers::Table>(serialized.data());
  PushFlatbuffer(schema_.get(), root_table);
  lua_setglobal(state_, "arg");

  RunScript(R"lua(
    result = {}
    for _, nested in pairs(arg.repeated_nested_field) do
      result[#result + 1] = nested.string_field
      for _, nested_string in pairs(nested.repeated_string_field) do
        result[#result + 1] = nested_string
      end
    end
    return result
  )lua");

  EXPECT_THAT(
      ReadVector<std::string>(),
      ElementsAre("hello", "my", "old", "friend", "how", "are", "you?"));
}
363 
// Two flatbuffers exposed as separate Lua globals must both be readable from
// the same script without one clobbering the other.
TEST_F(LuaUtilsTest, CorrectlyReadsTwoFlatbuffersSimultaneously) {
  // Exposes the flatbuffer serialized in `bytes` to Lua as the global `name`.
  // `bytes` is owned by the test body below and outlives the script run.
  auto expose_as_global = [this](const std::string& bytes, const char* name) {
    PushFlatbuffer(schema_.get(),
                   flatbuffers::GetRoot<flatbuffers::Table>(bytes.data()));
    lua_setglobal(state_, name);
  };

  // The first flatbuffer.
  std::unique_ptr<MutableFlatbuffer> first = flatbuffer_builder_.NewRoot();
  first->Set("string_field", "first");
  const std::string first_bytes = first->Serialize();
  expose_as_global(first_bytes, "arg");

  // The second flatbuffer.
  std::unique_ptr<MutableFlatbuffer> second = flatbuffer_builder_.NewRoot();
  second->Set("string_field", "second");
  const std::string second_bytes = second->Serialize();
  expose_as_global(second_bytes, "arg2");

  RunScript(R"lua(
    return {arg.string_field, arg2.string_field}
  )lua");

  EXPECT_THAT(ReadVector<std::string>(), ElementsAre("first", "second"));
}
386 
387 }  // namespace
388 }  // namespace libtextclassifier3
389