1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/feature-processor.h"
18
19 #include "actions/actions_model_generated.h"
20 #include "annotator/model-executor.h"
21 #include "utils/tensor-view.h"
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24
25 namespace libtextclassifier3 {
26 namespace {
27
28 using ::testing::FloatEq;
29 using ::testing::SizeIs;
30
31 // EmbeddingExecutor that always returns features based on
32 // the id of the sparse features.
33 class FakeEmbeddingExecutor : public EmbeddingExecutor {
34 public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const35 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
36 const int dest_size) const override {
37 TC3_CHECK_GE(dest_size, 4);
38 EXPECT_THAT(sparse_features, SizeIs(1));
39 dest[0] = sparse_features.data()[0];
40 dest[1] = sparse_features.data()[0];
41 dest[2] = -sparse_features.data()[0];
42 dest[3] = -sparse_features.data()[0];
43 return true;
44 }
45
46 private:
47 std::vector<float> storage_;
48 };
49
50 class ActionsFeatureProcessorTest : public ::testing::Test {
51 protected:
ActionsFeatureProcessorTest()52 ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
53
PackFeatureProcessorOptions(ActionsTokenFeatureProcessorOptionsT * options) const54 flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
55 ActionsTokenFeatureProcessorOptionsT* options) const {
56 flatbuffers::FlatBufferBuilder builder;
57 builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
58 return builder.Release();
59 }
60
61 FakeEmbeddingExecutor embedding_executor_;
62 UniLib unilib_;
63 };
64
// With embedding_size set to 4 and no other feature enabled, appending the
// features of one token is expected to yield exactly four values.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
  ActionsTokenFeatureProcessorOptionsT config;
  config.tokenizer_options.reset(new ActionsTokenizerOptionsT);
  config.embedding_size = 4;

  // The serialized buffer must outlive the processor that reads from it.
  flatbuffers::DetachedBuffer packed_config =
      PackFeatureProcessorOptions(&config);
  const auto* config_root =
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          packed_config.data());
  ActionsFeatureProcessor processor(config_root, &unilib_);

  const Token word("aaa", 0, 3);
  std::vector<float> features;
  EXPECT_TRUE(
      processor.AppendTokenFeatures(word, &embedding_executor_, &features));
  EXPECT_THAT(features, SizeIs(4));
}
83
// With the case feature enabled on top of a size-4 embedding, one token is
// expected to produce five values, the last of which is 1.0 for the
// capitalized token "Aaa".
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  // The serialized buffer must outlive the processor that reads from it.
  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  Token token("Aaa", 0, 3);
  std::vector<float> token_features;
  // Fatal assertions: the index access below is only valid once the append
  // succeeded and produced the expected number of values.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &token_features));
  ASSERT_THAT(token_features, SizeIs(5));
  EXPECT_THAT(token_features[4], FloatEq(1.0f));
}
104
// Appending three tokens with a size-4 embedding plus the case feature is
// expected to produce 15 values (5 per token). The value at the end of each
// group of five is expected to be 1.0 for the capitalized tokens and -1.0 for
// the lowercase one.
TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  // The serialized buffer must outlive the processor that reads from it.
  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
                                     Token("Cccc", 8, 12)};
  std::vector<float> token_features;
  // Fatal assertions: the index accesses below are only valid once the append
  // succeeded and produced the expected number of values.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(
      tokens, &embedding_executor_, &token_features));
  ASSERT_THAT(token_features, SizeIs(15));
  EXPECT_THAT(token_features[4], FloatEq(1.0f));    // "Aaa"
  EXPECT_THAT(token_features[9], FloatEq(-1.0f));   // "bbb"
  EXPECT_THAT(token_features[14], FloatEq(1.0f));   // "Cccc"
}
128
129 } // namespace
130 } // namespace libtextclassifier3
131