xref: /aosp_15_r20/external/libtextclassifier/native/annotator/cached-features_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "annotator/cached-features.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include "annotator/model-executor.h"
20*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h"
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #include "gmock/gmock.h"
23*993b0882SAndroid Build Coastguard Worker #include "gtest/gtest.h"
24*993b0882SAndroid Build Coastguard Worker 
25*993b0882SAndroid Build Coastguard Worker using testing::ElementsAreArray;
26*993b0882SAndroid Build Coastguard Worker using testing::FloatEq;
27*993b0882SAndroid Build Coastguard Worker using testing::Matcher;
28*993b0882SAndroid Build Coastguard Worker 
29*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
30*993b0882SAndroid Build Coastguard Worker namespace {
31*993b0882SAndroid Build Coastguard Worker 
ElementsAreFloat(const std::vector<float> & values)32*993b0882SAndroid Build Coastguard Worker Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
33*993b0882SAndroid Build Coastguard Worker   std::vector<Matcher<float>> matchers;
34*993b0882SAndroid Build Coastguard Worker   for (const float value : values) {
35*993b0882SAndroid Build Coastguard Worker     matchers.push_back(FloatEq(value));
36*993b0882SAndroid Build Coastguard Worker   }
37*993b0882SAndroid Build Coastguard Worker   return ElementsAreArray(matchers);
38*993b0882SAndroid Build Coastguard Worker }
39*993b0882SAndroid Build Coastguard Worker 
// Creates synthetic cached features for `num_tokens` tokens, three floats per
// token. Token k (1-based) contributes {11*k, -11*k, 0.1*k}, so each token's
// slice is easy to recognise in the expected outputs of the tests below.
// A non-positive `num_tokens` yields an empty vector.
std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
  std::unique_ptr<std::vector<float>> result(new std::vector<float>());
  if (num_tokens > 0) {
    result->reserve(3 * num_tokens);
  }
  for (int token = 1; token <= num_tokens; ++token) {
    const float magnitude = token * 11.0f;
    result->insert(result->end(), {magnitude, -magnitude, token * 0.1f});
  }
  return result;
}
49*993b0882SAndroid Build Coastguard Worker 
GetCachedClickContextFeatures(const CachedFeatures & cached_features,int click_pos)50*993b0882SAndroid Build Coastguard Worker std::vector<float> GetCachedClickContextFeatures(
51*993b0882SAndroid Build Coastguard Worker     const CachedFeatures& cached_features, int click_pos) {
52*993b0882SAndroid Build Coastguard Worker   std::vector<float> output_features;
53*993b0882SAndroid Build Coastguard Worker   cached_features.AppendClickContextFeaturesForClick(click_pos,
54*993b0882SAndroid Build Coastguard Worker                                                      &output_features);
55*993b0882SAndroid Build Coastguard Worker   return output_features;
56*993b0882SAndroid Build Coastguard Worker }
57*993b0882SAndroid Build Coastguard Worker 
GetCachedBoundsSensitiveFeatures(const CachedFeatures & cached_features,TokenSpan selected_span)58*993b0882SAndroid Build Coastguard Worker std::vector<float> GetCachedBoundsSensitiveFeatures(
59*993b0882SAndroid Build Coastguard Worker     const CachedFeatures& cached_features, TokenSpan selected_span) {
60*993b0882SAndroid Build Coastguard Worker   std::vector<float> output_features;
61*993b0882SAndroid Build Coastguard Worker   cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
62*993b0882SAndroid Build Coastguard Worker                                                        &output_features);
63*993b0882SAndroid Build Coastguard Worker   return output_features;
64*993b0882SAndroid Build Coastguard Worker }
65*993b0882SAndroid Build Coastguard Worker 
// Click-context features with context_size = 2: a click at absolute token
// position p is expected to yield the cached features of the five tokens
// p-2 .. p+2. The extraction span starts at 3, so a click at 5 begins with
// the first cached token ({11, -11, 0.1}).
TEST(CachedFeaturesTest, ClickContext) {
  FeatureProcessorOptionsT options;
  options.feature_version = 1;
  options.context_size = 2;

  flatbuffers::FlatBufferBuilder builder;
  const auto options_root = CreateFeatureProcessorOptions(builder, &options);
  builder.Finish(options_root);
  flatbuffers::DetachedBuffer options_fb = builder.Release();
  const FeatureProcessorOptions* const options_table =
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data());

  std::unique_ptr<std::vector<float>> padding(
      new std::vector<float>{112233.0, -112233.0, 321.0});

  const std::unique_ptr<CachedFeatures> cached_features =
      CachedFeatures::Create({3, 10}, MakeFeatures(9), std::move(padding),
                             options_table, /*feature_vector_size=*/3);
  ASSERT_TRUE(cached_features);

  // Tokens 1..5 of MakeFeatures(9).
  const std::vector<float> expected_click_5 = {
      11.0, -11.0, 0.1,  //
      22.0, -22.0, 0.2,  //
      33.0, -33.0, 0.3,  //
      44.0, -44.0, 0.4,  //
      55.0, -55.0, 0.5};
  // Tokens 2..6.
  const std::vector<float> expected_click_6 = {
      22.0, -22.0, 0.2,  //
      33.0, -33.0, 0.3,  //
      44.0, -44.0, 0.4,  //
      55.0, -55.0, 0.5,  //
      66.0, -66.0, 0.6};
  // Tokens 3..7.
  const std::vector<float> expected_click_7 = {
      33.0, -33.0, 0.3,  //
      44.0, -44.0, 0.4,  //
      55.0, -55.0, 0.5,  //
      66.0, -66.0, 0.6,  //
      77.0, -77.0, 0.7};

  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
              ElementsAreFloat(expected_click_5));
  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
              ElementsAreFloat(expected_click_6));
  EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
              ElementsAreFloat(expected_click_7));
}
97*993b0882SAndroid Build Coastguard Worker 
TEST(CachedFeaturesTest,BoundsSensitive)98*993b0882SAndroid Build Coastguard Worker TEST(CachedFeaturesTest, BoundsSensitive) {
99*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
100*993b0882SAndroid Build Coastguard Worker       new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
101*993b0882SAndroid Build Coastguard Worker   config->enabled = true;
102*993b0882SAndroid Build Coastguard Worker   config->num_tokens_before = 2;
103*993b0882SAndroid Build Coastguard Worker   config->num_tokens_inside_left = 2;
104*993b0882SAndroid Build Coastguard Worker   config->num_tokens_inside_right = 2;
105*993b0882SAndroid Build Coastguard Worker   config->num_tokens_after = 2;
106*993b0882SAndroid Build Coastguard Worker   config->include_inside_bag = true;
107*993b0882SAndroid Build Coastguard Worker   config->include_inside_length = true;
108*993b0882SAndroid Build Coastguard Worker   FeatureProcessorOptionsT options;
109*993b0882SAndroid Build Coastguard Worker   options.bounds_sensitive_features = std::move(config);
110*993b0882SAndroid Build Coastguard Worker   options.feature_version = 2;
111*993b0882SAndroid Build Coastguard Worker   flatbuffers::FlatBufferBuilder builder;
112*993b0882SAndroid Build Coastguard Worker   builder.Finish(CreateFeatureProcessorOptions(builder, &options));
113*993b0882SAndroid Build Coastguard Worker   flatbuffers::DetachedBuffer options_fb = builder.Release();
114*993b0882SAndroid Build Coastguard Worker 
115*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
116*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<std::vector<float>> padding_features(
117*993b0882SAndroid Build Coastguard Worker       new std::vector<float>{112233.0, -112233.0, 321.0});
118*993b0882SAndroid Build Coastguard Worker 
119*993b0882SAndroid Build Coastguard Worker   const std::unique_ptr<CachedFeatures> cached_features =
120*993b0882SAndroid Build Coastguard Worker       CachedFeatures::Create(
121*993b0882SAndroid Build Coastguard Worker           {3, 9}, std::move(features), std::move(padding_features),
122*993b0882SAndroid Build Coastguard Worker           flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
123*993b0882SAndroid Build Coastguard Worker           /*feature_vector_size=*/3);
124*993b0882SAndroid Build Coastguard Worker   ASSERT_TRUE(cached_features);
125*993b0882SAndroid Build Coastguard Worker 
126*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
127*993b0882SAndroid Build Coastguard Worker       GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
128*993b0882SAndroid Build Coastguard Worker       ElementsAreFloat({11.0,     -11.0,     0.1,   22.0,  -22.0, 0.2,   33.0,
129*993b0882SAndroid Build Coastguard Worker                         -33.0,    0.3,       44.0,  -44.0, 0.4,   44.0,  -44.0,
130*993b0882SAndroid Build Coastguard Worker                         0.4,      55.0,      -55.0, 0.5,   66.0,  -66.0, 0.6,
131*993b0882SAndroid Build Coastguard Worker                         112233.0, -112233.0, 321.0, 44.0,  -44.0, 0.4,   3.0}));
132*993b0882SAndroid Build Coastguard Worker 
133*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
134*993b0882SAndroid Build Coastguard Worker       GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
135*993b0882SAndroid Build Coastguard Worker       ElementsAreFloat({11.0,  -11.0, 0.1,   22.0,  -22.0, 0.2,   33.0,
136*993b0882SAndroid Build Coastguard Worker                         -33.0, 0.3,   44.0,  -44.0, 0.4,   33.0,  -33.0,
137*993b0882SAndroid Build Coastguard Worker                         0.3,   44.0,  -44.0, 0.4,   55.0,  -55.0, 0.5,
138*993b0882SAndroid Build Coastguard Worker                         66.0,  -66.0, 0.6,   38.5,  -38.5, 0.35,  2.0}));
139*993b0882SAndroid Build Coastguard Worker 
140*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
141*993b0882SAndroid Build Coastguard Worker       GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
142*993b0882SAndroid Build Coastguard Worker       ElementsAreFloat({22.0,     -22.0,     0.2,   33.0,  -33.0, 0.3,   44.0,
143*993b0882SAndroid Build Coastguard Worker                         -44.0,    0.4,       55.0,  -55.0, 0.5,   44.0,  -44.0,
144*993b0882SAndroid Build Coastguard Worker                         0.4,      55.0,      -55.0, 0.5,   66.0,  -66.0, 0.6,
145*993b0882SAndroid Build Coastguard Worker                         112233.0, -112233.0, 321.0, 49.5,  -49.5, 0.45,  2.0}));
146*993b0882SAndroid Build Coastguard Worker 
147*993b0882SAndroid Build Coastguard Worker   EXPECT_THAT(
148*993b0882SAndroid Build Coastguard Worker       GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
149*993b0882SAndroid Build Coastguard Worker       ElementsAreFloat({22.0,     -22.0,     0.2,   33.0,     -33.0,     0.3,
150*993b0882SAndroid Build Coastguard Worker                         44.0,     -44.0,     0.4,   112233.0, -112233.0, 321.0,
151*993b0882SAndroid Build Coastguard Worker                         112233.0, -112233.0, 321.0, 44.0,     -44.0,     0.4,
152*993b0882SAndroid Build Coastguard Worker                         55.0,     -55.0,     0.5,   66.0,     -66.0,     0.6,
153*993b0882SAndroid Build Coastguard Worker                         44.0,     -44.0,     0.4,   1.0}));
154*993b0882SAndroid Build Coastguard Worker }
155*993b0882SAndroid Build Coastguard Worker 
156*993b0882SAndroid Build Coastguard Worker }  // namespace
157*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
158