1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
#include "annotator/feature-processor.h"

#include <algorithm>
#include <iterator>
#include <set>
#include <vector>

#include "utils/base/logging.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
26*993b0882SAndroid Build Coastguard Worker
27*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
28*993b0882SAndroid Build Coastguard Worker
29*993b0882SAndroid Build Coastguard Worker namespace internal {
30*993b0882SAndroid Build Coastguard Worker
BuildTokenizer(const FeatureProcessorOptions * options,const UniLib * unilib)31*993b0882SAndroid Build Coastguard Worker Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
32*993b0882SAndroid Build Coastguard Worker const UniLib* unilib) {
33*993b0882SAndroid Build Coastguard Worker std::vector<const TokenizationCodepointRange*> codepoint_config;
34*993b0882SAndroid Build Coastguard Worker if (options->tokenization_codepoint_config() != nullptr) {
35*993b0882SAndroid Build Coastguard Worker codepoint_config.insert(codepoint_config.end(),
36*993b0882SAndroid Build Coastguard Worker options->tokenization_codepoint_config()->begin(),
37*993b0882SAndroid Build Coastguard Worker options->tokenization_codepoint_config()->end());
38*993b0882SAndroid Build Coastguard Worker }
39*993b0882SAndroid Build Coastguard Worker std::vector<const CodepointRange*> internal_codepoint_config;
40*993b0882SAndroid Build Coastguard Worker if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
41*993b0882SAndroid Build Coastguard Worker internal_codepoint_config.insert(
42*993b0882SAndroid Build Coastguard Worker internal_codepoint_config.end(),
43*993b0882SAndroid Build Coastguard Worker options->internal_tokenizer_codepoint_ranges()->begin(),
44*993b0882SAndroid Build Coastguard Worker options->internal_tokenizer_codepoint_ranges()->end());
45*993b0882SAndroid Build Coastguard Worker }
46*993b0882SAndroid Build Coastguard Worker const bool tokenize_on_script_change =
47*993b0882SAndroid Build Coastguard Worker options->tokenization_codepoint_config() != nullptr &&
48*993b0882SAndroid Build Coastguard Worker options->tokenize_on_script_change();
49*993b0882SAndroid Build Coastguard Worker return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
50*993b0882SAndroid Build Coastguard Worker internal_codepoint_config, tokenize_on_script_change,
51*993b0882SAndroid Build Coastguard Worker options->icu_preserve_whitespace_tokens());
52*993b0882SAndroid Build Coastguard Worker }
53*993b0882SAndroid Build Coastguard Worker
BuildTokenFeatureExtractorOptions(const FeatureProcessorOptions * const options)54*993b0882SAndroid Build Coastguard Worker TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
55*993b0882SAndroid Build Coastguard Worker const FeatureProcessorOptions* const options) {
56*993b0882SAndroid Build Coastguard Worker TokenFeatureExtractorOptions extractor_options;
57*993b0882SAndroid Build Coastguard Worker
58*993b0882SAndroid Build Coastguard Worker extractor_options.num_buckets = options->num_buckets();
59*993b0882SAndroid Build Coastguard Worker if (options->chargram_orders() != nullptr) {
60*993b0882SAndroid Build Coastguard Worker for (int order : *options->chargram_orders()) {
61*993b0882SAndroid Build Coastguard Worker extractor_options.chargram_orders.push_back(order);
62*993b0882SAndroid Build Coastguard Worker }
63*993b0882SAndroid Build Coastguard Worker }
64*993b0882SAndroid Build Coastguard Worker extractor_options.max_word_length = options->max_word_length();
65*993b0882SAndroid Build Coastguard Worker extractor_options.extract_case_feature = options->extract_case_feature();
66*993b0882SAndroid Build Coastguard Worker extractor_options.unicode_aware_features = options->unicode_aware_features();
67*993b0882SAndroid Build Coastguard Worker extractor_options.extract_selection_mask_feature =
68*993b0882SAndroid Build Coastguard Worker options->extract_selection_mask_feature();
69*993b0882SAndroid Build Coastguard Worker if (options->regexp_feature() != nullptr) {
70*993b0882SAndroid Build Coastguard Worker for (const auto& regexp_feature : *options->regexp_feature()) {
71*993b0882SAndroid Build Coastguard Worker extractor_options.regexp_features.push_back(regexp_feature->str());
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker }
74*993b0882SAndroid Build Coastguard Worker extractor_options.remap_digits = options->remap_digits();
75*993b0882SAndroid Build Coastguard Worker extractor_options.lowercase_tokens = options->lowercase_tokens();
76*993b0882SAndroid Build Coastguard Worker
77*993b0882SAndroid Build Coastguard Worker if (options->allowed_chargrams() != nullptr) {
78*993b0882SAndroid Build Coastguard Worker for (const auto& chargram : *options->allowed_chargrams()) {
79*993b0882SAndroid Build Coastguard Worker extractor_options.allowed_chargrams.insert(chargram->str());
80*993b0882SAndroid Build Coastguard Worker }
81*993b0882SAndroid Build Coastguard Worker }
82*993b0882SAndroid Build Coastguard Worker return extractor_options;
83*993b0882SAndroid Build Coastguard Worker }
84*993b0882SAndroid Build Coastguard Worker
SplitTokensOnSelectionBoundaries(const CodepointSpan & selection,std::vector<Token> * tokens)85*993b0882SAndroid Build Coastguard Worker void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
86*993b0882SAndroid Build Coastguard Worker std::vector<Token>* tokens) {
87*993b0882SAndroid Build Coastguard Worker for (auto it = tokens->begin(); it != tokens->end(); ++it) {
88*993b0882SAndroid Build Coastguard Worker const UnicodeText token_word =
89*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(it->value, /*do_copy=*/false);
90*993b0882SAndroid Build Coastguard Worker
91*993b0882SAndroid Build Coastguard Worker auto last_start = token_word.begin();
92*993b0882SAndroid Build Coastguard Worker int last_start_index = it->start;
93*993b0882SAndroid Build Coastguard Worker std::vector<UnicodeText::const_iterator> split_points;
94*993b0882SAndroid Build Coastguard Worker
95*993b0882SAndroid Build Coastguard Worker // Selection start split point.
96*993b0882SAndroid Build Coastguard Worker if (selection.first > it->start && selection.first < it->end) {
97*993b0882SAndroid Build Coastguard Worker std::advance(last_start, selection.first - last_start_index);
98*993b0882SAndroid Build Coastguard Worker split_points.push_back(last_start);
99*993b0882SAndroid Build Coastguard Worker last_start_index = selection.first;
100*993b0882SAndroid Build Coastguard Worker }
101*993b0882SAndroid Build Coastguard Worker
102*993b0882SAndroid Build Coastguard Worker // Selection end split point.
103*993b0882SAndroid Build Coastguard Worker if (selection.second > it->start && selection.second < it->end) {
104*993b0882SAndroid Build Coastguard Worker std::advance(last_start, selection.second - last_start_index);
105*993b0882SAndroid Build Coastguard Worker split_points.push_back(last_start);
106*993b0882SAndroid Build Coastguard Worker }
107*993b0882SAndroid Build Coastguard Worker
108*993b0882SAndroid Build Coastguard Worker if (!split_points.empty()) {
109*993b0882SAndroid Build Coastguard Worker // Add a final split for the rest of the token unless it's been all
110*993b0882SAndroid Build Coastguard Worker // consumed already.
111*993b0882SAndroid Build Coastguard Worker if (split_points.back() != token_word.end()) {
112*993b0882SAndroid Build Coastguard Worker split_points.push_back(token_word.end());
113*993b0882SAndroid Build Coastguard Worker }
114*993b0882SAndroid Build Coastguard Worker
115*993b0882SAndroid Build Coastguard Worker std::vector<Token> replacement_tokens;
116*993b0882SAndroid Build Coastguard Worker last_start = token_word.begin();
117*993b0882SAndroid Build Coastguard Worker int current_pos = it->start;
118*993b0882SAndroid Build Coastguard Worker for (const auto& split_point : split_points) {
119*993b0882SAndroid Build Coastguard Worker Token new_token(token_word.UTF8Substring(last_start, split_point),
120*993b0882SAndroid Build Coastguard Worker current_pos,
121*993b0882SAndroid Build Coastguard Worker current_pos + std::distance(last_start, split_point));
122*993b0882SAndroid Build Coastguard Worker
123*993b0882SAndroid Build Coastguard Worker last_start = split_point;
124*993b0882SAndroid Build Coastguard Worker current_pos = new_token.end;
125*993b0882SAndroid Build Coastguard Worker
126*993b0882SAndroid Build Coastguard Worker replacement_tokens.push_back(new_token);
127*993b0882SAndroid Build Coastguard Worker }
128*993b0882SAndroid Build Coastguard Worker
129*993b0882SAndroid Build Coastguard Worker it = tokens->erase(it);
130*993b0882SAndroid Build Coastguard Worker it = tokens->insert(it, replacement_tokens.begin(),
131*993b0882SAndroid Build Coastguard Worker replacement_tokens.end());
132*993b0882SAndroid Build Coastguard Worker std::advance(it, replacement_tokens.size() - 1);
133*993b0882SAndroid Build Coastguard Worker }
134*993b0882SAndroid Build Coastguard Worker }
135*993b0882SAndroid Build Coastguard Worker }
136*993b0882SAndroid Build Coastguard Worker
137*993b0882SAndroid Build Coastguard Worker } // namespace internal
138*993b0882SAndroid Build Coastguard Worker
StripTokensFromOtherLines(const std::string & context,const CodepointSpan & span,std::vector<Token> * tokens) const139*993b0882SAndroid Build Coastguard Worker void FeatureProcessor::StripTokensFromOtherLines(
140*993b0882SAndroid Build Coastguard Worker const std::string& context, const CodepointSpan& span,
141*993b0882SAndroid Build Coastguard Worker std::vector<Token>* tokens) const {
142*993b0882SAndroid Build Coastguard Worker const UnicodeText context_unicode = UTF8ToUnicodeText(context,
143*993b0882SAndroid Build Coastguard Worker /*do_copy=*/false);
144*993b0882SAndroid Build Coastguard Worker const auto [span_begin, span_end] =
145*993b0882SAndroid Build Coastguard Worker CodepointSpanToUnicodeTextRange(context_unicode, span);
146*993b0882SAndroid Build Coastguard Worker StripTokensFromOtherLines(context_unicode, span_begin, span_end, span,
147*993b0882SAndroid Build Coastguard Worker tokens);
148*993b0882SAndroid Build Coastguard Worker }
149*993b0882SAndroid Build Coastguard Worker
StripTokensFromOtherLines(const UnicodeText & context_unicode,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & span,std::vector<Token> * tokens) const150*993b0882SAndroid Build Coastguard Worker void FeatureProcessor::StripTokensFromOtherLines(
151*993b0882SAndroid Build Coastguard Worker const UnicodeText& context_unicode,
152*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_begin,
153*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_end, const CodepointSpan& span,
154*993b0882SAndroid Build Coastguard Worker std::vector<Token>* tokens) const {
155*993b0882SAndroid Build Coastguard Worker std::vector<UnicodeTextRange> lines =
156*993b0882SAndroid Build Coastguard Worker SplitContext(context_unicode, options_->use_pipe_character_for_newline());
157*993b0882SAndroid Build Coastguard Worker
158*993b0882SAndroid Build Coastguard Worker for (const UnicodeTextRange& line : lines) {
159*993b0882SAndroid Build Coastguard Worker // Find the line that completely contains the span.
160*993b0882SAndroid Build Coastguard Worker if (line.first <= span_begin && line.second >= span_end) {
161*993b0882SAndroid Build Coastguard Worker const CodepointIndex last_line_begin_index =
162*993b0882SAndroid Build Coastguard Worker std::distance(context_unicode.begin(), line.first);
163*993b0882SAndroid Build Coastguard Worker const CodepointIndex last_line_end_index =
164*993b0882SAndroid Build Coastguard Worker last_line_begin_index + std::distance(line.first, line.second);
165*993b0882SAndroid Build Coastguard Worker
166*993b0882SAndroid Build Coastguard Worker for (auto token = tokens->begin(); token != tokens->end();) {
167*993b0882SAndroid Build Coastguard Worker if (token->start >= last_line_begin_index &&
168*993b0882SAndroid Build Coastguard Worker token->end <= last_line_end_index) {
169*993b0882SAndroid Build Coastguard Worker ++token;
170*993b0882SAndroid Build Coastguard Worker } else {
171*993b0882SAndroid Build Coastguard Worker token = tokens->erase(token);
172*993b0882SAndroid Build Coastguard Worker }
173*993b0882SAndroid Build Coastguard Worker }
174*993b0882SAndroid Build Coastguard Worker }
175*993b0882SAndroid Build Coastguard Worker }
176*993b0882SAndroid Build Coastguard Worker }
177*993b0882SAndroid Build Coastguard Worker
GetDefaultCollection() const178*993b0882SAndroid Build Coastguard Worker std::string FeatureProcessor::GetDefaultCollection() const {
179*993b0882SAndroid Build Coastguard Worker if (options_->default_collection() < 0 ||
180*993b0882SAndroid Build Coastguard Worker options_->collections() == nullptr ||
181*993b0882SAndroid Build Coastguard Worker options_->default_collection() >= options_->collections()->size()) {
182*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR)
183*993b0882SAndroid Build Coastguard Worker << "Invalid or missing default collection. Returning empty string.";
184*993b0882SAndroid Build Coastguard Worker return "";
185*993b0882SAndroid Build Coastguard Worker }
186*993b0882SAndroid Build Coastguard Worker return (*options_->collections())[options_->default_collection()]->str();
187*993b0882SAndroid Build Coastguard Worker }
188*993b0882SAndroid Build Coastguard Worker
Tokenize(const std::string & text) const189*993b0882SAndroid Build Coastguard Worker std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
190*993b0882SAndroid Build Coastguard Worker return tokenizer_.Tokenize(text);
191*993b0882SAndroid Build Coastguard Worker }
192*993b0882SAndroid Build Coastguard Worker
Tokenize(const UnicodeText & text_unicode) const193*993b0882SAndroid Build Coastguard Worker std::vector<Token> FeatureProcessor::Tokenize(
194*993b0882SAndroid Build Coastguard Worker const UnicodeText& text_unicode) const {
195*993b0882SAndroid Build Coastguard Worker return tokenizer_.Tokenize(text_unicode);
196*993b0882SAndroid Build Coastguard Worker }
197*993b0882SAndroid Build Coastguard Worker
LabelToSpan(const int label,const VectorSpan<Token> & tokens,CodepointSpan * span) const198*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::LabelToSpan(const int label,
199*993b0882SAndroid Build Coastguard Worker const VectorSpan<Token>& tokens,
200*993b0882SAndroid Build Coastguard Worker CodepointSpan* span) const {
201*993b0882SAndroid Build Coastguard Worker if (tokens.size() != GetNumContextTokens()) {
202*993b0882SAndroid Build Coastguard Worker return false;
203*993b0882SAndroid Build Coastguard Worker }
204*993b0882SAndroid Build Coastguard Worker
205*993b0882SAndroid Build Coastguard Worker TokenSpan token_span;
206*993b0882SAndroid Build Coastguard Worker if (!LabelToTokenSpan(label, &token_span)) {
207*993b0882SAndroid Build Coastguard Worker return false;
208*993b0882SAndroid Build Coastguard Worker }
209*993b0882SAndroid Build Coastguard Worker
210*993b0882SAndroid Build Coastguard Worker const int result_begin_token_index = token_span.first;
211*993b0882SAndroid Build Coastguard Worker const Token& result_begin_token =
212*993b0882SAndroid Build Coastguard Worker tokens[options_->context_size() - result_begin_token_index];
213*993b0882SAndroid Build Coastguard Worker const int result_begin_codepoint = result_begin_token.start;
214*993b0882SAndroid Build Coastguard Worker const int result_end_token_index = token_span.second;
215*993b0882SAndroid Build Coastguard Worker const Token& result_end_token =
216*993b0882SAndroid Build Coastguard Worker tokens[options_->context_size() + result_end_token_index];
217*993b0882SAndroid Build Coastguard Worker const int result_end_codepoint = result_end_token.end;
218*993b0882SAndroid Build Coastguard Worker
219*993b0882SAndroid Build Coastguard Worker if (result_begin_codepoint == kInvalidIndex ||
220*993b0882SAndroid Build Coastguard Worker result_end_codepoint == kInvalidIndex) {
221*993b0882SAndroid Build Coastguard Worker *span = CodepointSpan::kInvalid;
222*993b0882SAndroid Build Coastguard Worker } else {
223*993b0882SAndroid Build Coastguard Worker const UnicodeText token_begin_unicode =
224*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
225*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
226*993b0882SAndroid Build Coastguard Worker const UnicodeText token_end_unicode =
227*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
228*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator token_end = token_end_unicode.end();
229*993b0882SAndroid Build Coastguard Worker
230*993b0882SAndroid Build Coastguard Worker const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
231*993b0882SAndroid Build Coastguard Worker token_begin, token_begin_unicode.end(),
232*993b0882SAndroid Build Coastguard Worker /*count_from_beginning=*/true);
233*993b0882SAndroid Build Coastguard Worker const int end_ignored =
234*993b0882SAndroid Build Coastguard Worker CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
235*993b0882SAndroid Build Coastguard Worker /*count_from_beginning=*/false);
236*993b0882SAndroid Build Coastguard Worker // In case everything would be stripped, set the span to the original
237*993b0882SAndroid Build Coastguard Worker // beginning and zero length.
238*993b0882SAndroid Build Coastguard Worker if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
239*993b0882SAndroid Build Coastguard Worker *span = {result_begin_codepoint, result_begin_codepoint};
240*993b0882SAndroid Build Coastguard Worker } else {
241*993b0882SAndroid Build Coastguard Worker *span = CodepointSpan(result_begin_codepoint + begin_ignored,
242*993b0882SAndroid Build Coastguard Worker result_end_codepoint - end_ignored);
243*993b0882SAndroid Build Coastguard Worker }
244*993b0882SAndroid Build Coastguard Worker }
245*993b0882SAndroid Build Coastguard Worker return true;
246*993b0882SAndroid Build Coastguard Worker }
247*993b0882SAndroid Build Coastguard Worker
LabelToTokenSpan(const int label,TokenSpan * token_span) const248*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::LabelToTokenSpan(const int label,
249*993b0882SAndroid Build Coastguard Worker TokenSpan* token_span) const {
250*993b0882SAndroid Build Coastguard Worker if (label >= 0 && label < label_to_selection_.size()) {
251*993b0882SAndroid Build Coastguard Worker *token_span = label_to_selection_[label];
252*993b0882SAndroid Build Coastguard Worker return true;
253*993b0882SAndroid Build Coastguard Worker } else {
254*993b0882SAndroid Build Coastguard Worker return false;
255*993b0882SAndroid Build Coastguard Worker }
256*993b0882SAndroid Build Coastguard Worker }
257*993b0882SAndroid Build Coastguard Worker
SpanToLabel(const CodepointSpan & span,const std::vector<Token> & tokens,int * label) const258*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::SpanToLabel(const CodepointSpan& span,
259*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& tokens,
260*993b0882SAndroid Build Coastguard Worker int* label) const {
261*993b0882SAndroid Build Coastguard Worker if (tokens.size() != GetNumContextTokens()) {
262*993b0882SAndroid Build Coastguard Worker return false;
263*993b0882SAndroid Build Coastguard Worker }
264*993b0882SAndroid Build Coastguard Worker
265*993b0882SAndroid Build Coastguard Worker const int click_position =
266*993b0882SAndroid Build Coastguard Worker options_->context_size(); // Click is always in the middle.
267*993b0882SAndroid Build Coastguard Worker const int padding = options_->context_size() - options_->max_selection_span();
268*993b0882SAndroid Build Coastguard Worker
269*993b0882SAndroid Build Coastguard Worker int span_left = 0;
270*993b0882SAndroid Build Coastguard Worker for (int i = click_position - 1; i >= padding; i--) {
271*993b0882SAndroid Build Coastguard Worker if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
272*993b0882SAndroid Build Coastguard Worker ++span_left;
273*993b0882SAndroid Build Coastguard Worker } else {
274*993b0882SAndroid Build Coastguard Worker break;
275*993b0882SAndroid Build Coastguard Worker }
276*993b0882SAndroid Build Coastguard Worker }
277*993b0882SAndroid Build Coastguard Worker
278*993b0882SAndroid Build Coastguard Worker int span_right = 0;
279*993b0882SAndroid Build Coastguard Worker for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
280*993b0882SAndroid Build Coastguard Worker if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
281*993b0882SAndroid Build Coastguard Worker ++span_right;
282*993b0882SAndroid Build Coastguard Worker } else {
283*993b0882SAndroid Build Coastguard Worker break;
284*993b0882SAndroid Build Coastguard Worker }
285*993b0882SAndroid Build Coastguard Worker }
286*993b0882SAndroid Build Coastguard Worker
287*993b0882SAndroid Build Coastguard Worker // Check that the spanned tokens cover the whole span.
288*993b0882SAndroid Build Coastguard Worker bool tokens_match_span;
289*993b0882SAndroid Build Coastguard Worker const CodepointIndex tokens_start = tokens[click_position - span_left].start;
290*993b0882SAndroid Build Coastguard Worker const CodepointIndex tokens_end = tokens[click_position + span_right].end;
291*993b0882SAndroid Build Coastguard Worker if (options_->snap_label_span_boundaries_to_containing_tokens()) {
292*993b0882SAndroid Build Coastguard Worker tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
293*993b0882SAndroid Build Coastguard Worker } else {
294*993b0882SAndroid Build Coastguard Worker const UnicodeText token_left_unicode = UTF8ToUnicodeText(
295*993b0882SAndroid Build Coastguard Worker tokens[click_position - span_left].value, /*do_copy=*/false);
296*993b0882SAndroid Build Coastguard Worker const UnicodeText token_right_unicode = UTF8ToUnicodeText(
297*993b0882SAndroid Build Coastguard Worker tokens[click_position + span_right].value, /*do_copy=*/false);
298*993b0882SAndroid Build Coastguard Worker
299*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator span_begin = token_left_unicode.begin();
300*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator span_end = token_right_unicode.end();
301*993b0882SAndroid Build Coastguard Worker
302*993b0882SAndroid Build Coastguard Worker const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
303*993b0882SAndroid Build Coastguard Worker span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
304*993b0882SAndroid Build Coastguard Worker const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
305*993b0882SAndroid Build Coastguard Worker token_right_unicode.begin(), span_end,
306*993b0882SAndroid Build Coastguard Worker /*count_from_beginning=*/false);
307*993b0882SAndroid Build Coastguard Worker
308*993b0882SAndroid Build Coastguard Worker tokens_match_span = tokens_start <= span.first &&
309*993b0882SAndroid Build Coastguard Worker tokens_start + num_punctuation_start >= span.first &&
310*993b0882SAndroid Build Coastguard Worker tokens_end >= span.second &&
311*993b0882SAndroid Build Coastguard Worker tokens_end - num_punctuation_end <= span.second;
312*993b0882SAndroid Build Coastguard Worker }
313*993b0882SAndroid Build Coastguard Worker
314*993b0882SAndroid Build Coastguard Worker if (tokens_match_span) {
315*993b0882SAndroid Build Coastguard Worker *label = TokenSpanToLabel({span_left, span_right});
316*993b0882SAndroid Build Coastguard Worker } else {
317*993b0882SAndroid Build Coastguard Worker *label = kInvalidLabel;
318*993b0882SAndroid Build Coastguard Worker }
319*993b0882SAndroid Build Coastguard Worker
320*993b0882SAndroid Build Coastguard Worker return true;
321*993b0882SAndroid Build Coastguard Worker }
322*993b0882SAndroid Build Coastguard Worker
TokenSpanToLabel(const TokenSpan & token_span) const323*993b0882SAndroid Build Coastguard Worker int FeatureProcessor::TokenSpanToLabel(const TokenSpan& token_span) const {
324*993b0882SAndroid Build Coastguard Worker auto it = selection_to_label_.find(token_span);
325*993b0882SAndroid Build Coastguard Worker if (it != selection_to_label_.end()) {
326*993b0882SAndroid Build Coastguard Worker return it->second;
327*993b0882SAndroid Build Coastguard Worker } else {
328*993b0882SAndroid Build Coastguard Worker return kInvalidLabel;
329*993b0882SAndroid Build Coastguard Worker }
330*993b0882SAndroid Build Coastguard Worker }
331*993b0882SAndroid Build Coastguard Worker
CodepointSpanToTokenSpan(const std::vector<Token> & selectable_tokens,const CodepointSpan & codepoint_span,bool snap_boundaries_to_containing_tokens)332*993b0882SAndroid Build Coastguard Worker TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
333*993b0882SAndroid Build Coastguard Worker const CodepointSpan& codepoint_span,
334*993b0882SAndroid Build Coastguard Worker bool snap_boundaries_to_containing_tokens) {
335*993b0882SAndroid Build Coastguard Worker const int codepoint_start = codepoint_span.first;
336*993b0882SAndroid Build Coastguard Worker const int codepoint_end = codepoint_span.second;
337*993b0882SAndroid Build Coastguard Worker
338*993b0882SAndroid Build Coastguard Worker TokenIndex start_token = kInvalidIndex;
339*993b0882SAndroid Build Coastguard Worker TokenIndex end_token = kInvalidIndex;
340*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < selectable_tokens.size(); ++i) {
341*993b0882SAndroid Build Coastguard Worker bool is_token_in_span;
342*993b0882SAndroid Build Coastguard Worker if (snap_boundaries_to_containing_tokens) {
343*993b0882SAndroid Build Coastguard Worker is_token_in_span = codepoint_start < selectable_tokens[i].end &&
344*993b0882SAndroid Build Coastguard Worker codepoint_end > selectable_tokens[i].start;
345*993b0882SAndroid Build Coastguard Worker } else {
346*993b0882SAndroid Build Coastguard Worker is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
347*993b0882SAndroid Build Coastguard Worker codepoint_end >= selectable_tokens[i].end;
348*993b0882SAndroid Build Coastguard Worker }
349*993b0882SAndroid Build Coastguard Worker if (is_token_in_span && !selectable_tokens[i].is_padding) {
350*993b0882SAndroid Build Coastguard Worker if (start_token == kInvalidIndex) {
351*993b0882SAndroid Build Coastguard Worker start_token = i;
352*993b0882SAndroid Build Coastguard Worker }
353*993b0882SAndroid Build Coastguard Worker end_token = i + 1;
354*993b0882SAndroid Build Coastguard Worker }
355*993b0882SAndroid Build Coastguard Worker }
356*993b0882SAndroid Build Coastguard Worker return {start_token, end_token};
357*993b0882SAndroid Build Coastguard Worker }
358*993b0882SAndroid Build Coastguard Worker
TokenSpanToCodepointSpan(const std::vector<Token> & selectable_tokens,const TokenSpan & token_span)359*993b0882SAndroid Build Coastguard Worker CodepointSpan TokenSpanToCodepointSpan(
360*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& selectable_tokens, const TokenSpan& token_span) {
361*993b0882SAndroid Build Coastguard Worker return {selectable_tokens[token_span.first].start,
362*993b0882SAndroid Build Coastguard Worker selectable_tokens[token_span.second - 1].end};
363*993b0882SAndroid Build Coastguard Worker }
364*993b0882SAndroid Build Coastguard Worker
CodepointSpanToUnicodeTextRange(const UnicodeText & unicode_text,const CodepointSpan & span)365*993b0882SAndroid Build Coastguard Worker UnicodeTextRange CodepointSpanToUnicodeTextRange(
366*993b0882SAndroid Build Coastguard Worker const UnicodeText& unicode_text, const CodepointSpan& span) {
367*993b0882SAndroid Build Coastguard Worker auto begin = unicode_text.begin();
368*993b0882SAndroid Build Coastguard Worker if (span.first > 0) {
369*993b0882SAndroid Build Coastguard Worker std::advance(begin, span.first);
370*993b0882SAndroid Build Coastguard Worker }
371*993b0882SAndroid Build Coastguard Worker auto end = unicode_text.begin();
372*993b0882SAndroid Build Coastguard Worker if (span.second > 0) {
373*993b0882SAndroid Build Coastguard Worker std::advance(end, span.second);
374*993b0882SAndroid Build Coastguard Worker }
375*993b0882SAndroid Build Coastguard Worker return {begin, end};
376*993b0882SAndroid Build Coastguard Worker }
377*993b0882SAndroid Build Coastguard Worker
378*993b0882SAndroid Build Coastguard Worker namespace {
379*993b0882SAndroid Build Coastguard Worker
380*993b0882SAndroid Build Coastguard Worker // Finds a single token that completely contains the given span.
FindTokenThatContainsSpan(const std::vector<Token> & selectable_tokens,const CodepointSpan & codepoint_span)381*993b0882SAndroid Build Coastguard Worker int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
382*993b0882SAndroid Build Coastguard Worker const CodepointSpan& codepoint_span) {
383*993b0882SAndroid Build Coastguard Worker const int codepoint_start = codepoint_span.first;
384*993b0882SAndroid Build Coastguard Worker const int codepoint_end = codepoint_span.second;
385*993b0882SAndroid Build Coastguard Worker
386*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < selectable_tokens.size(); ++i) {
387*993b0882SAndroid Build Coastguard Worker if (codepoint_start >= selectable_tokens[i].start &&
388*993b0882SAndroid Build Coastguard Worker codepoint_end <= selectable_tokens[i].end) {
389*993b0882SAndroid Build Coastguard Worker return i;
390*993b0882SAndroid Build Coastguard Worker }
391*993b0882SAndroid Build Coastguard Worker }
392*993b0882SAndroid Build Coastguard Worker return kInvalidIndex;
393*993b0882SAndroid Build Coastguard Worker }
394*993b0882SAndroid Build Coastguard Worker
395*993b0882SAndroid Build Coastguard Worker } // namespace
396*993b0882SAndroid Build Coastguard Worker
397*993b0882SAndroid Build Coastguard Worker namespace internal {
398*993b0882SAndroid Build Coastguard Worker
CenterTokenFromClick(const CodepointSpan & span,const std::vector<Token> & selectable_tokens)399*993b0882SAndroid Build Coastguard Worker int CenterTokenFromClick(const CodepointSpan& span,
400*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& selectable_tokens) {
401*993b0882SAndroid Build Coastguard Worker const TokenSpan token_span =
402*993b0882SAndroid Build Coastguard Worker CodepointSpanToTokenSpan(selectable_tokens, span);
403*993b0882SAndroid Build Coastguard Worker int range_begin = token_span.first;
404*993b0882SAndroid Build Coastguard Worker int range_end = token_span.second;
405*993b0882SAndroid Build Coastguard Worker
406*993b0882SAndroid Build Coastguard Worker // If no exact match was found, try finding a token that completely contains
407*993b0882SAndroid Build Coastguard Worker // the click span. This is useful e.g. when Android builds the selection
408*993b0882SAndroid Build Coastguard Worker // using ICU tokenization, and ends up with only a portion of our space-
409*993b0882SAndroid Build Coastguard Worker // separated token. E.g. for "(857)" Android would select "857".
410*993b0882SAndroid Build Coastguard Worker if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
411*993b0882SAndroid Build Coastguard Worker int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
412*993b0882SAndroid Build Coastguard Worker if (token_index != kInvalidIndex) {
413*993b0882SAndroid Build Coastguard Worker range_begin = token_index;
414*993b0882SAndroid Build Coastguard Worker range_end = token_index + 1;
415*993b0882SAndroid Build Coastguard Worker }
416*993b0882SAndroid Build Coastguard Worker }
417*993b0882SAndroid Build Coastguard Worker
418*993b0882SAndroid Build Coastguard Worker // We only allow clicks that are exactly 1 selectable token.
419*993b0882SAndroid Build Coastguard Worker if (range_end - range_begin == 1) {
420*993b0882SAndroid Build Coastguard Worker return range_begin;
421*993b0882SAndroid Build Coastguard Worker } else {
422*993b0882SAndroid Build Coastguard Worker return kInvalidIndex;
423*993b0882SAndroid Build Coastguard Worker }
424*993b0882SAndroid Build Coastguard Worker }
425*993b0882SAndroid Build Coastguard Worker
CenterTokenFromMiddleOfSelection(const CodepointSpan & span,const std::vector<Token> & selectable_tokens)426*993b0882SAndroid Build Coastguard Worker int CenterTokenFromMiddleOfSelection(
427*993b0882SAndroid Build Coastguard Worker const CodepointSpan& span, const std::vector<Token>& selectable_tokens) {
428*993b0882SAndroid Build Coastguard Worker const TokenSpan token_span =
429*993b0882SAndroid Build Coastguard Worker CodepointSpanToTokenSpan(selectable_tokens, span);
430*993b0882SAndroid Build Coastguard Worker const int range_begin = token_span.first;
431*993b0882SAndroid Build Coastguard Worker const int range_end = token_span.second;
432*993b0882SAndroid Build Coastguard Worker
433*993b0882SAndroid Build Coastguard Worker // Center the clicked token in the selection range.
434*993b0882SAndroid Build Coastguard Worker if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
435*993b0882SAndroid Build Coastguard Worker return (range_begin + range_end - 1) / 2;
436*993b0882SAndroid Build Coastguard Worker } else {
437*993b0882SAndroid Build Coastguard Worker return kInvalidIndex;
438*993b0882SAndroid Build Coastguard Worker }
439*993b0882SAndroid Build Coastguard Worker }
440*993b0882SAndroid Build Coastguard Worker
441*993b0882SAndroid Build Coastguard Worker } // namespace internal
442*993b0882SAndroid Build Coastguard Worker
FindCenterToken(const CodepointSpan & span,const std::vector<Token> & tokens) const443*993b0882SAndroid Build Coastguard Worker int FeatureProcessor::FindCenterToken(const CodepointSpan& span,
444*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& tokens) const {
445*993b0882SAndroid Build Coastguard Worker if (options_->center_token_selection_method() ==
446*993b0882SAndroid Build Coastguard Worker FeatureProcessorOptions_::
447*993b0882SAndroid Build Coastguard Worker CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
448*993b0882SAndroid Build Coastguard Worker return internal::CenterTokenFromClick(span, tokens);
449*993b0882SAndroid Build Coastguard Worker } else if (options_->center_token_selection_method() ==
450*993b0882SAndroid Build Coastguard Worker FeatureProcessorOptions_::
451*993b0882SAndroid Build Coastguard Worker CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
452*993b0882SAndroid Build Coastguard Worker return internal::CenterTokenFromMiddleOfSelection(span, tokens);
453*993b0882SAndroid Build Coastguard Worker } else if (options_->center_token_selection_method() ==
454*993b0882SAndroid Build Coastguard Worker FeatureProcessorOptions_::
455*993b0882SAndroid Build Coastguard Worker CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
456*993b0882SAndroid Build Coastguard Worker // TODO(zilka): Remove once we have new models on the device.
457*993b0882SAndroid Build Coastguard Worker // It uses the fact that sharing model use
458*993b0882SAndroid Build Coastguard Worker // split_tokens_on_selection_boundaries and selection not. So depending on
459*993b0882SAndroid Build Coastguard Worker // this we select the right way of finding the click location.
460*993b0882SAndroid Build Coastguard Worker if (!options_->split_tokens_on_selection_boundaries()) {
461*993b0882SAndroid Build Coastguard Worker // SmartSelection model.
462*993b0882SAndroid Build Coastguard Worker return internal::CenterTokenFromClick(span, tokens);
463*993b0882SAndroid Build Coastguard Worker } else {
464*993b0882SAndroid Build Coastguard Worker // SmartSharing model.
465*993b0882SAndroid Build Coastguard Worker return internal::CenterTokenFromMiddleOfSelection(span, tokens);
466*993b0882SAndroid Build Coastguard Worker }
467*993b0882SAndroid Build Coastguard Worker } else {
468*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Invalid center token selection method.";
469*993b0882SAndroid Build Coastguard Worker return kInvalidIndex;
470*993b0882SAndroid Build Coastguard Worker }
471*993b0882SAndroid Build Coastguard Worker }
472*993b0882SAndroid Build Coastguard Worker
SelectionLabelSpans(const VectorSpan<Token> tokens,std::vector<CodepointSpan> * selection_label_spans) const473*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::SelectionLabelSpans(
474*993b0882SAndroid Build Coastguard Worker const VectorSpan<Token> tokens,
475*993b0882SAndroid Build Coastguard Worker std::vector<CodepointSpan>* selection_label_spans) const {
476*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < label_to_selection_.size(); ++i) {
477*993b0882SAndroid Build Coastguard Worker CodepointSpan span = CodepointSpan::kInvalid;
478*993b0882SAndroid Build Coastguard Worker if (!LabelToSpan(i, tokens, &span)) {
479*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not convert label to span: " << i;
480*993b0882SAndroid Build Coastguard Worker return false;
481*993b0882SAndroid Build Coastguard Worker }
482*993b0882SAndroid Build Coastguard Worker selection_label_spans->push_back(span);
483*993b0882SAndroid Build Coastguard Worker }
484*993b0882SAndroid Build Coastguard Worker return true;
485*993b0882SAndroid Build Coastguard Worker }
486*993b0882SAndroid Build Coastguard Worker
SelectionLabelRelativeTokenSpans(std::vector<TokenSpan> * selection_label_relative_token_spans) const487*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::SelectionLabelRelativeTokenSpans(
488*993b0882SAndroid Build Coastguard Worker std::vector<TokenSpan>* selection_label_relative_token_spans) const {
489*993b0882SAndroid Build Coastguard Worker selection_label_relative_token_spans->assign(label_to_selection_.begin(),
490*993b0882SAndroid Build Coastguard Worker label_to_selection_.end());
491*993b0882SAndroid Build Coastguard Worker return true;
492*993b0882SAndroid Build Coastguard Worker }
493*993b0882SAndroid Build Coastguard Worker
PrepareIgnoredSpanBoundaryCodepoints()494*993b0882SAndroid Build Coastguard Worker void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
495*993b0882SAndroid Build Coastguard Worker if (options_->ignored_span_boundary_codepoints() != nullptr) {
496*993b0882SAndroid Build Coastguard Worker for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
497*993b0882SAndroid Build Coastguard Worker ignored_span_boundary_codepoints_.insert(codepoint);
498*993b0882SAndroid Build Coastguard Worker }
499*993b0882SAndroid Build Coastguard Worker }
500*993b0882SAndroid Build Coastguard Worker }
501*993b0882SAndroid Build Coastguard Worker
CountIgnoredSpanBoundaryCodepoints(const UnicodeText::const_iterator & span_start,const UnicodeText::const_iterator & span_end,bool count_from_beginning) const502*993b0882SAndroid Build Coastguard Worker int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
503*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_start,
504*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_end,
505*993b0882SAndroid Build Coastguard Worker bool count_from_beginning) const {
506*993b0882SAndroid Build Coastguard Worker if (span_start == span_end) {
507*993b0882SAndroid Build Coastguard Worker return 0;
508*993b0882SAndroid Build Coastguard Worker }
509*993b0882SAndroid Build Coastguard Worker
510*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator it;
511*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator it_last;
512*993b0882SAndroid Build Coastguard Worker if (count_from_beginning) {
513*993b0882SAndroid Build Coastguard Worker it = span_start;
514*993b0882SAndroid Build Coastguard Worker it_last = span_end;
515*993b0882SAndroid Build Coastguard Worker // We can assume that the string is non-zero length because of the check
516*993b0882SAndroid Build Coastguard Worker // above, thus the decrement is always valid here.
517*993b0882SAndroid Build Coastguard Worker --it_last;
518*993b0882SAndroid Build Coastguard Worker } else {
519*993b0882SAndroid Build Coastguard Worker it = span_end;
520*993b0882SAndroid Build Coastguard Worker it_last = span_start;
521*993b0882SAndroid Build Coastguard Worker // We can assume that the string is non-zero length because of the check
522*993b0882SAndroid Build Coastguard Worker // above, thus the decrement is always valid here.
523*993b0882SAndroid Build Coastguard Worker --it;
524*993b0882SAndroid Build Coastguard Worker }
525*993b0882SAndroid Build Coastguard Worker
526*993b0882SAndroid Build Coastguard Worker // Move until we encounter a non-ignored character.
527*993b0882SAndroid Build Coastguard Worker int num_ignored = 0;
528*993b0882SAndroid Build Coastguard Worker while (ignored_span_boundary_codepoints_.find(*it) !=
529*993b0882SAndroid Build Coastguard Worker ignored_span_boundary_codepoints_.end()) {
530*993b0882SAndroid Build Coastguard Worker ++num_ignored;
531*993b0882SAndroid Build Coastguard Worker
532*993b0882SAndroid Build Coastguard Worker if (it == it_last) {
533*993b0882SAndroid Build Coastguard Worker break;
534*993b0882SAndroid Build Coastguard Worker }
535*993b0882SAndroid Build Coastguard Worker
536*993b0882SAndroid Build Coastguard Worker if (count_from_beginning) {
537*993b0882SAndroid Build Coastguard Worker ++it;
538*993b0882SAndroid Build Coastguard Worker } else {
539*993b0882SAndroid Build Coastguard Worker --it;
540*993b0882SAndroid Build Coastguard Worker }
541*993b0882SAndroid Build Coastguard Worker }
542*993b0882SAndroid Build Coastguard Worker
543*993b0882SAndroid Build Coastguard Worker return num_ignored;
544*993b0882SAndroid Build Coastguard Worker }
545*993b0882SAndroid Build Coastguard Worker
546*993b0882SAndroid Build Coastguard Worker namespace {
547*993b0882SAndroid Build Coastguard Worker
FindSubstrings(const UnicodeText & t,const std::set<char32> & codepoints,std::vector<UnicodeTextRange> * ranges)548*993b0882SAndroid Build Coastguard Worker void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
549*993b0882SAndroid Build Coastguard Worker std::vector<UnicodeTextRange>* ranges) {
550*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator start = t.begin();
551*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator curr = start;
552*993b0882SAndroid Build Coastguard Worker UnicodeText::const_iterator end = t.end();
553*993b0882SAndroid Build Coastguard Worker for (; curr != end; ++curr) {
554*993b0882SAndroid Build Coastguard Worker if (codepoints.find(*curr) != codepoints.end()) {
555*993b0882SAndroid Build Coastguard Worker if (start != curr) {
556*993b0882SAndroid Build Coastguard Worker ranges->push_back(std::make_pair(start, curr));
557*993b0882SAndroid Build Coastguard Worker }
558*993b0882SAndroid Build Coastguard Worker start = curr;
559*993b0882SAndroid Build Coastguard Worker ++start;
560*993b0882SAndroid Build Coastguard Worker }
561*993b0882SAndroid Build Coastguard Worker }
562*993b0882SAndroid Build Coastguard Worker if (start != end) {
563*993b0882SAndroid Build Coastguard Worker ranges->push_back(std::make_pair(start, end));
564*993b0882SAndroid Build Coastguard Worker }
565*993b0882SAndroid Build Coastguard Worker }
566*993b0882SAndroid Build Coastguard Worker
567*993b0882SAndroid Build Coastguard Worker } // namespace
568*993b0882SAndroid Build Coastguard Worker
SplitContext(const UnicodeText & context_unicode,const bool use_pipe_character_for_newline) const569*993b0882SAndroid Build Coastguard Worker std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
570*993b0882SAndroid Build Coastguard Worker const UnicodeText& context_unicode,
571*993b0882SAndroid Build Coastguard Worker const bool use_pipe_character_for_newline) const {
572*993b0882SAndroid Build Coastguard Worker std::vector<UnicodeTextRange> lines;
573*993b0882SAndroid Build Coastguard Worker std::set<char32> codepoints{'\n'};
574*993b0882SAndroid Build Coastguard Worker if (use_pipe_character_for_newline) {
575*993b0882SAndroid Build Coastguard Worker codepoints.insert('|');
576*993b0882SAndroid Build Coastguard Worker }
577*993b0882SAndroid Build Coastguard Worker FindSubstrings(context_unicode, codepoints, &lines);
578*993b0882SAndroid Build Coastguard Worker return lines;
579*993b0882SAndroid Build Coastguard Worker }
580*993b0882SAndroid Build Coastguard Worker
// UTF-8 convenience overload. It wraps `context` as UnicodeText without
// copying and delegates to the UnicodeText overload.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const std::string& context, const CodepointSpan& span) const {
  const UnicodeText wrapped_context =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  const CodepointSpan stripped = StripBoundaryCodepoints(wrapped_context, span);
  return stripped;
}
587*993b0882SAndroid Build Coastguard Worker
// Removes ignored boundary codepoints from both ends of `span` within
// `context_unicode`. It returns `span` unchanged when the context is empty or
// the span is invalid or empty. Otherwise it resolves the span to iterators and
// delegates to the iterator overload.
CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    const UnicodeText& context_unicode, const CodepointSpan& span) const {
  const bool nothing_to_strip =
      context_unicode.empty() || !span.IsValid() || span.IsEmpty();
  if (nothing_to_strip) {
    return span;
  }

  const auto [range_begin, range_end] =
      CodepointSpanToUnicodeTextRange(context_unicode, span);
  return StripBoundaryCodepoints(range_begin, range_end, span);
}
599*993b0882SAndroid Build Coastguard Worker
StripBoundaryCodepoints(const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & span) const600*993b0882SAndroid Build Coastguard Worker CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
601*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_begin,
602*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_end,
603*993b0882SAndroid Build Coastguard Worker const CodepointSpan& span) const {
604*993b0882SAndroid Build Coastguard Worker if (!span.IsValid() || span.IsEmpty() || span_begin == span_end) {
605*993b0882SAndroid Build Coastguard Worker return span;
606*993b0882SAndroid Build Coastguard Worker }
607*993b0882SAndroid Build Coastguard Worker
608*993b0882SAndroid Build Coastguard Worker const int start_offset = CountIgnoredSpanBoundaryCodepoints(
609*993b0882SAndroid Build Coastguard Worker span_begin, span_end, /*count_from_beginning=*/true);
610*993b0882SAndroid Build Coastguard Worker const int end_offset = CountIgnoredSpanBoundaryCodepoints(
611*993b0882SAndroid Build Coastguard Worker span_begin, span_end, /*count_from_beginning=*/false);
612*993b0882SAndroid Build Coastguard Worker
613*993b0882SAndroid Build Coastguard Worker if (span.first + start_offset < span.second - end_offset) {
614*993b0882SAndroid Build Coastguard Worker return {span.first + start_offset, span.second - end_offset};
615*993b0882SAndroid Build Coastguard Worker } else {
616*993b0882SAndroid Build Coastguard Worker return {span.first, span.first};
617*993b0882SAndroid Build Coastguard Worker }
618*993b0882SAndroid Build Coastguard Worker }
619*993b0882SAndroid Build Coastguard Worker
SupportedCodepointsRatio(const TokenSpan & token_span,const std::vector<Token> & tokens) const620*993b0882SAndroid Build Coastguard Worker float FeatureProcessor::SupportedCodepointsRatio(
621*993b0882SAndroid Build Coastguard Worker const TokenSpan& token_span, const std::vector<Token>& tokens) const {
622*993b0882SAndroid Build Coastguard Worker int num_supported = 0;
623*993b0882SAndroid Build Coastguard Worker int num_total = 0;
624*993b0882SAndroid Build Coastguard Worker for (int i = token_span.first; i < token_span.second; ++i) {
625*993b0882SAndroid Build Coastguard Worker const UnicodeText value =
626*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
627*993b0882SAndroid Build Coastguard Worker for (auto codepoint : value) {
628*993b0882SAndroid Build Coastguard Worker if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
629*993b0882SAndroid Build Coastguard Worker ++num_supported;
630*993b0882SAndroid Build Coastguard Worker }
631*993b0882SAndroid Build Coastguard Worker ++num_total;
632*993b0882SAndroid Build Coastguard Worker }
633*993b0882SAndroid Build Coastguard Worker }
634*993b0882SAndroid Build Coastguard Worker // Avoid division by zero.
635*993b0882SAndroid Build Coastguard Worker if (num_total == 0) {
636*993b0882SAndroid Build Coastguard Worker return 0.0;
637*993b0882SAndroid Build Coastguard Worker }
638*993b0882SAndroid Build Coastguard Worker return static_cast<float>(num_supported) / static_cast<float>(num_total);
639*993b0882SAndroid Build Coastguard Worker }
640*993b0882SAndroid Build Coastguard Worker
StripBoundaryCodepoints(const std::string & value,std::string * buffer) const641*993b0882SAndroid Build Coastguard Worker const std::string& FeatureProcessor::StripBoundaryCodepoints(
642*993b0882SAndroid Build Coastguard Worker const std::string& value, std::string* buffer) const {
643*993b0882SAndroid Build Coastguard Worker const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
644*993b0882SAndroid Build Coastguard Worker const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
645*993b0882SAndroid Build Coastguard Worker const CodepointSpan stripped_span =
646*993b0882SAndroid Build Coastguard Worker StripBoundaryCodepoints(value_unicode, initial_span);
647*993b0882SAndroid Build Coastguard Worker
648*993b0882SAndroid Build Coastguard Worker if (initial_span != stripped_span) {
649*993b0882SAndroid Build Coastguard Worker const UnicodeText stripped_token_value =
650*993b0882SAndroid Build Coastguard Worker UnicodeText::Substring(value_unicode, stripped_span.first,
651*993b0882SAndroid Build Coastguard Worker stripped_span.second, /*do_copy=*/false);
652*993b0882SAndroid Build Coastguard Worker *buffer = stripped_token_value.ToUTF8String();
653*993b0882SAndroid Build Coastguard Worker return *buffer;
654*993b0882SAndroid Build Coastguard Worker }
655*993b0882SAndroid Build Coastguard Worker return value;
656*993b0882SAndroid Build Coastguard Worker }
657*993b0882SAndroid Build Coastguard Worker
CollectionToLabel(const std::string & collection) const658*993b0882SAndroid Build Coastguard Worker int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
659*993b0882SAndroid Build Coastguard Worker const auto it = collection_to_label_.find(collection);
660*993b0882SAndroid Build Coastguard Worker if (it == collection_to_label_.end()) {
661*993b0882SAndroid Build Coastguard Worker return options_->default_collection();
662*993b0882SAndroid Build Coastguard Worker } else {
663*993b0882SAndroid Build Coastguard Worker return it->second;
664*993b0882SAndroid Build Coastguard Worker }
665*993b0882SAndroid Build Coastguard Worker }
666*993b0882SAndroid Build Coastguard Worker
LabelToCollection(int label) const667*993b0882SAndroid Build Coastguard Worker std::string FeatureProcessor::LabelToCollection(int label) const {
668*993b0882SAndroid Build Coastguard Worker if (label >= 0 && label < collection_to_label_.size()) {
669*993b0882SAndroid Build Coastguard Worker return (*options_->collections())[label]->str();
670*993b0882SAndroid Build Coastguard Worker } else {
671*993b0882SAndroid Build Coastguard Worker return GetDefaultCollection();
672*993b0882SAndroid Build Coastguard Worker }
673*993b0882SAndroid Build Coastguard Worker }
674*993b0882SAndroid Build Coastguard Worker
MakeLabelMaps()675*993b0882SAndroid Build Coastguard Worker void FeatureProcessor::MakeLabelMaps() {
676*993b0882SAndroid Build Coastguard Worker if (options_->collections() != nullptr) {
677*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < options_->collections()->size(); ++i) {
678*993b0882SAndroid Build Coastguard Worker collection_to_label_[(*options_->collections())[i]->str()] = i;
679*993b0882SAndroid Build Coastguard Worker }
680*993b0882SAndroid Build Coastguard Worker }
681*993b0882SAndroid Build Coastguard Worker
682*993b0882SAndroid Build Coastguard Worker int selection_label_id = 0;
683*993b0882SAndroid Build Coastguard Worker for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
684*993b0882SAndroid Build Coastguard Worker for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
685*993b0882SAndroid Build Coastguard Worker if (!options_->selection_reduced_output_space() ||
686*993b0882SAndroid Build Coastguard Worker r + l <= options_->max_selection_span()) {
687*993b0882SAndroid Build Coastguard Worker TokenSpan token_span{l, r};
688*993b0882SAndroid Build Coastguard Worker selection_to_label_[token_span] = selection_label_id;
689*993b0882SAndroid Build Coastguard Worker label_to_selection_.push_back(token_span);
690*993b0882SAndroid Build Coastguard Worker ++selection_label_id;
691*993b0882SAndroid Build Coastguard Worker }
692*993b0882SAndroid Build Coastguard Worker }
693*993b0882SAndroid Build Coastguard Worker }
694*993b0882SAndroid Build Coastguard Worker }
695*993b0882SAndroid Build Coastguard Worker
RetokenizeAndFindClick(const std::string & context,const CodepointSpan & input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const696*993b0882SAndroid Build Coastguard Worker void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
697*993b0882SAndroid Build Coastguard Worker const CodepointSpan& input_span,
698*993b0882SAndroid Build Coastguard Worker bool only_use_line_with_click,
699*993b0882SAndroid Build Coastguard Worker std::vector<Token>* tokens,
700*993b0882SAndroid Build Coastguard Worker int* click_pos) const {
701*993b0882SAndroid Build Coastguard Worker const UnicodeText context_unicode =
702*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(context, /*do_copy=*/false);
703*993b0882SAndroid Build Coastguard Worker const auto [span_begin, span_end] =
704*993b0882SAndroid Build Coastguard Worker CodepointSpanToUnicodeTextRange(context_unicode, input_span);
705*993b0882SAndroid Build Coastguard Worker RetokenizeAndFindClick(context_unicode, span_begin, span_end, input_span,
706*993b0882SAndroid Build Coastguard Worker only_use_line_with_click, tokens, click_pos);
707*993b0882SAndroid Build Coastguard Worker }
708*993b0882SAndroid Build Coastguard Worker
RetokenizeAndFindClick(const UnicodeText & context_unicode,const UnicodeText::const_iterator & span_begin,const UnicodeText::const_iterator & span_end,const CodepointSpan & input_span,bool only_use_line_with_click,std::vector<Token> * tokens,int * click_pos) const709*993b0882SAndroid Build Coastguard Worker void FeatureProcessor::RetokenizeAndFindClick(
710*993b0882SAndroid Build Coastguard Worker const UnicodeText& context_unicode,
711*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_begin,
712*993b0882SAndroid Build Coastguard Worker const UnicodeText::const_iterator& span_end,
713*993b0882SAndroid Build Coastguard Worker const CodepointSpan& input_span, bool only_use_line_with_click,
714*993b0882SAndroid Build Coastguard Worker std::vector<Token>* tokens, int* click_pos) const {
715*993b0882SAndroid Build Coastguard Worker TC3_CHECK(tokens != nullptr);
716*993b0882SAndroid Build Coastguard Worker
717*993b0882SAndroid Build Coastguard Worker if (options_->split_tokens_on_selection_boundaries()) {
718*993b0882SAndroid Build Coastguard Worker internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
719*993b0882SAndroid Build Coastguard Worker }
720*993b0882SAndroid Build Coastguard Worker
721*993b0882SAndroid Build Coastguard Worker if (only_use_line_with_click) {
722*993b0882SAndroid Build Coastguard Worker StripTokensFromOtherLines(context_unicode, span_begin, span_end, input_span,
723*993b0882SAndroid Build Coastguard Worker tokens);
724*993b0882SAndroid Build Coastguard Worker }
725*993b0882SAndroid Build Coastguard Worker
726*993b0882SAndroid Build Coastguard Worker int local_click_pos;
727*993b0882SAndroid Build Coastguard Worker if (click_pos == nullptr) {
728*993b0882SAndroid Build Coastguard Worker click_pos = &local_click_pos;
729*993b0882SAndroid Build Coastguard Worker }
730*993b0882SAndroid Build Coastguard Worker *click_pos = FindCenterToken(input_span, *tokens);
731*993b0882SAndroid Build Coastguard Worker if (*click_pos == kInvalidIndex) {
732*993b0882SAndroid Build Coastguard Worker // If the default click method failed, let's try to do sub-token matching
733*993b0882SAndroid Build Coastguard Worker // before we fail.
734*993b0882SAndroid Build Coastguard Worker *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
735*993b0882SAndroid Build Coastguard Worker }
736*993b0882SAndroid Build Coastguard Worker }
737*993b0882SAndroid Build Coastguard Worker
738*993b0882SAndroid Build Coastguard Worker namespace internal {
739*993b0882SAndroid Build Coastguard Worker
StripOrPadTokens(const TokenSpan & relative_click_span,int context_size,std::vector<Token> * tokens,int * click_pos)740*993b0882SAndroid Build Coastguard Worker void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
741*993b0882SAndroid Build Coastguard Worker std::vector<Token>* tokens, int* click_pos) {
742*993b0882SAndroid Build Coastguard Worker int right_context_needed = relative_click_span.second + context_size;
743*993b0882SAndroid Build Coastguard Worker if (*click_pos + right_context_needed + 1 >= tokens->size()) {
744*993b0882SAndroid Build Coastguard Worker // Pad max the context size.
745*993b0882SAndroid Build Coastguard Worker const int num_pad_tokens = std::min(
746*993b0882SAndroid Build Coastguard Worker context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
747*993b0882SAndroid Build Coastguard Worker tokens->size()));
748*993b0882SAndroid Build Coastguard Worker std::vector<Token> pad_tokens(num_pad_tokens);
749*993b0882SAndroid Build Coastguard Worker tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
750*993b0882SAndroid Build Coastguard Worker } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
751*993b0882SAndroid Build Coastguard Worker // Strip unused tokens.
752*993b0882SAndroid Build Coastguard Worker auto it = tokens->begin();
753*993b0882SAndroid Build Coastguard Worker std::advance(it, *click_pos + right_context_needed + 1);
754*993b0882SAndroid Build Coastguard Worker tokens->erase(it, tokens->end());
755*993b0882SAndroid Build Coastguard Worker }
756*993b0882SAndroid Build Coastguard Worker
757*993b0882SAndroid Build Coastguard Worker int left_context_needed = relative_click_span.first + context_size;
758*993b0882SAndroid Build Coastguard Worker if (*click_pos < left_context_needed) {
759*993b0882SAndroid Build Coastguard Worker // Pad max the context size.
760*993b0882SAndroid Build Coastguard Worker const int num_pad_tokens =
761*993b0882SAndroid Build Coastguard Worker std::min(context_size, left_context_needed - *click_pos);
762*993b0882SAndroid Build Coastguard Worker std::vector<Token> pad_tokens(num_pad_tokens);
763*993b0882SAndroid Build Coastguard Worker tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
764*993b0882SAndroid Build Coastguard Worker *click_pos += num_pad_tokens;
765*993b0882SAndroid Build Coastguard Worker } else if (*click_pos > left_context_needed) {
766*993b0882SAndroid Build Coastguard Worker // Strip unused tokens.
767*993b0882SAndroid Build Coastguard Worker auto it = tokens->begin();
768*993b0882SAndroid Build Coastguard Worker std::advance(it, *click_pos - left_context_needed);
769*993b0882SAndroid Build Coastguard Worker *click_pos -= it - tokens->begin();
770*993b0882SAndroid Build Coastguard Worker tokens->erase(tokens->begin(), it);
771*993b0882SAndroid Build Coastguard Worker }
772*993b0882SAndroid Build Coastguard Worker }
773*993b0882SAndroid Build Coastguard Worker
774*993b0882SAndroid Build Coastguard Worker } // namespace internal
775*993b0882SAndroid Build Coastguard Worker
HasEnoughSupportedCodepoints(const std::vector<Token> & tokens,const TokenSpan & token_span) const776*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::HasEnoughSupportedCodepoints(
777*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& tokens, const TokenSpan& token_span) const {
778*993b0882SAndroid Build Coastguard Worker if (options_->min_supported_codepoint_ratio() > 0) {
779*993b0882SAndroid Build Coastguard Worker const float supported_codepoint_ratio =
780*993b0882SAndroid Build Coastguard Worker SupportedCodepointsRatio(token_span, tokens);
781*993b0882SAndroid Build Coastguard Worker if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
782*993b0882SAndroid Build Coastguard Worker TC3_VLOG(1) << "Not enough supported codepoints in the context: "
783*993b0882SAndroid Build Coastguard Worker << supported_codepoint_ratio;
784*993b0882SAndroid Build Coastguard Worker return false;
785*993b0882SAndroid Build Coastguard Worker }
786*993b0882SAndroid Build Coastguard Worker }
787*993b0882SAndroid Build Coastguard Worker return true;
788*993b0882SAndroid Build Coastguard Worker }
789*993b0882SAndroid Build Coastguard Worker
ExtractFeatures(const std::vector<Token> & tokens,const TokenSpan & token_span,const CodepointSpan & selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,int feature_vector_size,std::unique_ptr<CachedFeatures> * cached_features) const790*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::ExtractFeatures(
791*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& tokens, const TokenSpan& token_span,
792*993b0882SAndroid Build Coastguard Worker const CodepointSpan& selection_span_for_feature,
793*993b0882SAndroid Build Coastguard Worker const EmbeddingExecutor* embedding_executor,
794*993b0882SAndroid Build Coastguard Worker EmbeddingCache* embedding_cache, int feature_vector_size,
795*993b0882SAndroid Build Coastguard Worker std::unique_ptr<CachedFeatures>* cached_features) const {
796*993b0882SAndroid Build Coastguard Worker std::unique_ptr<std::vector<float>> features(new std::vector<float>());
797*993b0882SAndroid Build Coastguard Worker features->reserve(feature_vector_size * token_span.Size());
798*993b0882SAndroid Build Coastguard Worker for (int i = token_span.first; i < token_span.second; ++i) {
799*993b0882SAndroid Build Coastguard Worker if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
800*993b0882SAndroid Build Coastguard Worker embedding_executor, embedding_cache,
801*993b0882SAndroid Build Coastguard Worker features.get())) {
802*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not get token features.";
803*993b0882SAndroid Build Coastguard Worker return false;
804*993b0882SAndroid Build Coastguard Worker }
805*993b0882SAndroid Build Coastguard Worker }
806*993b0882SAndroid Build Coastguard Worker
807*993b0882SAndroid Build Coastguard Worker std::unique_ptr<std::vector<float>> padding_features(
808*993b0882SAndroid Build Coastguard Worker new std::vector<float>());
809*993b0882SAndroid Build Coastguard Worker padding_features->reserve(feature_vector_size);
810*993b0882SAndroid Build Coastguard Worker if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
811*993b0882SAndroid Build Coastguard Worker embedding_executor, embedding_cache,
812*993b0882SAndroid Build Coastguard Worker padding_features.get())) {
813*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Count not get padding token features.";
814*993b0882SAndroid Build Coastguard Worker return false;
815*993b0882SAndroid Build Coastguard Worker }
816*993b0882SAndroid Build Coastguard Worker
817*993b0882SAndroid Build Coastguard Worker *cached_features = CachedFeatures::Create(token_span, std::move(features),
818*993b0882SAndroid Build Coastguard Worker std::move(padding_features),
819*993b0882SAndroid Build Coastguard Worker options_, feature_vector_size);
820*993b0882SAndroid Build Coastguard Worker if (!*cached_features) {
821*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Cound not create cached features.";
822*993b0882SAndroid Build Coastguard Worker return false;
823*993b0882SAndroid Build Coastguard Worker }
824*993b0882SAndroid Build Coastguard Worker
825*993b0882SAndroid Build Coastguard Worker return true;
826*993b0882SAndroid Build Coastguard Worker }
827*993b0882SAndroid Build Coastguard Worker
AppendTokenFeaturesWithCache(const Token & token,const CodepointSpan & selection_span_for_feature,const EmbeddingExecutor * embedding_executor,EmbeddingCache * embedding_cache,std::vector<float> * output_features) const828*993b0882SAndroid Build Coastguard Worker bool FeatureProcessor::AppendTokenFeaturesWithCache(
829*993b0882SAndroid Build Coastguard Worker const Token& token, const CodepointSpan& selection_span_for_feature,
830*993b0882SAndroid Build Coastguard Worker const EmbeddingExecutor* embedding_executor,
831*993b0882SAndroid Build Coastguard Worker EmbeddingCache* embedding_cache,
832*993b0882SAndroid Build Coastguard Worker std::vector<float>* output_features) const {
833*993b0882SAndroid Build Coastguard Worker // Look for the embedded features for the token in the cache, if there is one.
834*993b0882SAndroid Build Coastguard Worker if (embedding_cache) {
835*993b0882SAndroid Build Coastguard Worker const auto it = embedding_cache->find({token.start, token.end});
836*993b0882SAndroid Build Coastguard Worker if (it != embedding_cache->end()) {
837*993b0882SAndroid Build Coastguard Worker // The embedded features were found in the cache, extract only the dense
838*993b0882SAndroid Build Coastguard Worker // features.
839*993b0882SAndroid Build Coastguard Worker std::vector<float> dense_features;
840*993b0882SAndroid Build Coastguard Worker if (!feature_extractor_.Extract(
841*993b0882SAndroid Build Coastguard Worker token, token.IsContainedInSpan(selection_span_for_feature),
842*993b0882SAndroid Build Coastguard Worker /*sparse_features=*/nullptr, &dense_features)) {
843*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not extract token's dense features.";
844*993b0882SAndroid Build Coastguard Worker return false;
845*993b0882SAndroid Build Coastguard Worker }
846*993b0882SAndroid Build Coastguard Worker
847*993b0882SAndroid Build Coastguard Worker // Append both embedded and dense features to the output and return.
848*993b0882SAndroid Build Coastguard Worker output_features->insert(output_features->end(), it->second.begin(),
849*993b0882SAndroid Build Coastguard Worker it->second.end());
850*993b0882SAndroid Build Coastguard Worker output_features->insert(output_features->end(), dense_features.begin(),
851*993b0882SAndroid Build Coastguard Worker dense_features.end());
852*993b0882SAndroid Build Coastguard Worker return true;
853*993b0882SAndroid Build Coastguard Worker }
854*993b0882SAndroid Build Coastguard Worker }
855*993b0882SAndroid Build Coastguard Worker
856*993b0882SAndroid Build Coastguard Worker // Extract the sparse and dense features.
857*993b0882SAndroid Build Coastguard Worker std::vector<int> sparse_features;
858*993b0882SAndroid Build Coastguard Worker std::vector<float> dense_features;
859*993b0882SAndroid Build Coastguard Worker if (!feature_extractor_.Extract(
860*993b0882SAndroid Build Coastguard Worker token, token.IsContainedInSpan(selection_span_for_feature),
861*993b0882SAndroid Build Coastguard Worker &sparse_features, &dense_features)) {
862*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Could not extract token's features.";
863*993b0882SAndroid Build Coastguard Worker return false;
864*993b0882SAndroid Build Coastguard Worker }
865*993b0882SAndroid Build Coastguard Worker
866*993b0882SAndroid Build Coastguard Worker // Embed the sparse features, appending them directly to the output.
867*993b0882SAndroid Build Coastguard Worker const int embedding_size = GetOptions()->embedding_size();
868*993b0882SAndroid Build Coastguard Worker output_features->resize(output_features->size() + embedding_size);
869*993b0882SAndroid Build Coastguard Worker float* output_features_end =
870*993b0882SAndroid Build Coastguard Worker output_features->data() + output_features->size();
871*993b0882SAndroid Build Coastguard Worker if (!embedding_executor->AddEmbedding(
872*993b0882SAndroid Build Coastguard Worker TensorView<int>(sparse_features.data(),
873*993b0882SAndroid Build Coastguard Worker {static_cast<int>(sparse_features.size())}),
874*993b0882SAndroid Build Coastguard Worker /*dest=*/output_features_end - embedding_size,
875*993b0882SAndroid Build Coastguard Worker /*dest_size=*/embedding_size)) {
876*993b0882SAndroid Build Coastguard Worker TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
877*993b0882SAndroid Build Coastguard Worker return false;
878*993b0882SAndroid Build Coastguard Worker }
879*993b0882SAndroid Build Coastguard Worker
880*993b0882SAndroid Build Coastguard Worker // If there is a cache, the embedded features for the token were not in it,
881*993b0882SAndroid Build Coastguard Worker // so insert them.
882*993b0882SAndroid Build Coastguard Worker if (embedding_cache) {
883*993b0882SAndroid Build Coastguard Worker (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
884*993b0882SAndroid Build Coastguard Worker output_features_end - embedding_size, output_features_end);
885*993b0882SAndroid Build Coastguard Worker }
886*993b0882SAndroid Build Coastguard Worker
887*993b0882SAndroid Build Coastguard Worker // Append the dense features to the output.
888*993b0882SAndroid Build Coastguard Worker output_features->insert(output_features->end(), dense_features.begin(),
889*993b0882SAndroid Build Coastguard Worker dense_features.end());
890*993b0882SAndroid Build Coastguard Worker return true;
891*993b0882SAndroid Build Coastguard Worker }
892*993b0882SAndroid Build Coastguard Worker
893*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
894