xref: /aosp_15_r20/external/libtextclassifier/native/actions/feature-processor_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/feature-processor.h"
18 
19 #include "actions/actions_model_generated.h"
20 #include "annotator/model-executor.h"
21 #include "utils/tensor-view.h"
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 
25 namespace libtextclassifier3 {
26 namespace {
27 
28 using ::testing::FloatEq;
29 using ::testing::SizeIs;
30 
31 // EmbeddingExecutor that always returns features based on
32 // the id of the sparse features.
33 class FakeEmbeddingExecutor : public EmbeddingExecutor {
34  public:
AddEmbedding(const TensorView<int> & sparse_features,float * dest,const int dest_size) const35   bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
36                     const int dest_size) const override {
37     TC3_CHECK_GE(dest_size, 4);
38     EXPECT_THAT(sparse_features, SizeIs(1));
39     dest[0] = sparse_features.data()[0];
40     dest[1] = sparse_features.data()[0];
41     dest[2] = -sparse_features.data()[0];
42     dest[3] = -sparse_features.data()[0];
43     return true;
44   }
45 
46  private:
47   std::vector<float> storage_;
48 };
49 
50 class ActionsFeatureProcessorTest : public ::testing::Test {
51  protected:
ActionsFeatureProcessorTest()52   ActionsFeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
53 
PackFeatureProcessorOptions(ActionsTokenFeatureProcessorOptionsT * options) const54   flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
55       ActionsTokenFeatureProcessorOptionsT* options) const {
56     flatbuffers::FlatBufferBuilder builder;
57     builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
58     return builder.Release();
59   }
60 
61   FakeEmbeddingExecutor embedding_executor_;
62   UniLib unilib_;
63 };
64 
// With the case feature left disabled, a single token yields only its
// embedding: exactly `embedding_size` (here 4) floats.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddings) {
  ActionsTokenFeatureProcessorOptionsT processor_options;
  processor_options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
  processor_options.embedding_size = 4;

  flatbuffers::DetachedBuffer packed_options =
      PackFeatureProcessorOptions(&processor_options);
  const ActionsTokenFeatureProcessorOptions* const options_table =
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          packed_options.data());
  ActionsFeatureProcessor feature_processor(options_table, &unilib_);

  const Token token("aaa", 0, 3);
  std::vector<float> features;
  EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &features));
  EXPECT_THAT(features, SizeIs(4));
}
83 
// With the case feature enabled, a capitalized token gets one extra float
// appended after its 4-value embedding, and that value is 1.0.
TEST_F(ActionsFeatureProcessorTest, TokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  Token token("Aaa", 0, 3);
  std::vector<float> token_features;
  // Fatal ASSERTs: indexing token_features[4] below is only valid if feature
  // extraction succeeded and produced the expected number of values.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
                                                    &token_features));
  ASSERT_THAT(token_features, SizeIs(5));
  EXPECT_THAT(token_features[4], FloatEq(1.0f));
}
104 
// With the case feature enabled, each token contributes its 4-value embedding
// followed by one case value, so the per-token stride is 5. The expected case
// values are 1.0 for "Aaa" and "Cccc" (capitalized) and -1.0 for "bbb".
TEST_F(ActionsFeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
  ActionsTokenFeatureProcessorOptionsT options;
  options.embedding_size = 4;
  options.extract_case_feature = true;
  options.tokenizer_options.reset(new ActionsTokenizerOptionsT);

  flatbuffers::DetachedBuffer options_fb =
      PackFeatureProcessorOptions(&options);
  ActionsFeatureProcessor feature_processor(
      flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
          options_fb.data()),
      &unilib_);

  const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
                                     Token("Cccc", 8, 12)};
  std::vector<float> token_features;
  // Fatal ASSERTs: the element accesses below are only in bounds if feature
  // extraction succeeded and produced 3 tokens x (4 embedding + 1 case) values.
  ASSERT_TRUE(feature_processor.AppendTokenFeatures(
      tokens, &embedding_executor_, &token_features));
  ASSERT_THAT(token_features, SizeIs(15));
  EXPECT_THAT(token_features[4], FloatEq(1.0f));
  EXPECT_THAT(token_features[9], FloatEq(-1.0f));
  EXPECT_THAT(token_features[14], FloatEq(1.0f));
}
128 
129 }  // namespace
130 }  // namespace libtextclassifier3
131