xref: /aosp_15_r20/external/libtextclassifier/native/annotator/feature-processor.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
// Feature processing for FFModel (feed-forward SmartSelection model).
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
20*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
21*993b0882SAndroid Build Coastguard Worker 
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

#include "annotator/cached-features.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/base/logging.h"
#include "utils/token-feature-extractor.h"
#include "utils/tokenizer.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
37*993b0882SAndroid Build Coastguard Worker 
38*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
39*993b0882SAndroid Build Coastguard Worker 
// Sentinel label value (-1) used to denote an invalid label.
constexpr int kInvalidLabel = -1;
41*993b0882SAndroid Build Coastguard Worker 
42*993b0882SAndroid Build Coastguard Worker namespace internal {
43*993b0882SAndroid Build Coastguard Worker 
44*993b0882SAndroid Build Coastguard Worker Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
45*993b0882SAndroid Build Coastguard Worker                          const UniLib* unilib);
46*993b0882SAndroid Build Coastguard Worker 
47*993b0882SAndroid Build Coastguard Worker TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
48*993b0882SAndroid Build Coastguard Worker     const FeatureProcessorOptions* options);
49*993b0882SAndroid Build Coastguard Worker 
50*993b0882SAndroid Build Coastguard Worker // Splits tokens that contain the selection boundary inside them.
51*993b0882SAndroid Build Coastguard Worker // E.g. "foo{bar}@google.com" -> "foo", "bar", "@google.com"
52*993b0882SAndroid Build Coastguard Worker void SplitTokensOnSelectionBoundaries(const CodepointSpan& selection,
53*993b0882SAndroid Build Coastguard Worker                                       std::vector<Token>* tokens);
54*993b0882SAndroid Build Coastguard Worker 
55*993b0882SAndroid Build Coastguard Worker // Returns the index of token that corresponds to the codepoint span.
56*993b0882SAndroid Build Coastguard Worker int CenterTokenFromClick(const CodepointSpan& span,
57*993b0882SAndroid Build Coastguard Worker                          const std::vector<Token>& tokens);
58*993b0882SAndroid Build Coastguard Worker 
59*993b0882SAndroid Build Coastguard Worker // Returns the index of token that corresponds to the middle of the  codepoint
60*993b0882SAndroid Build Coastguard Worker // span.
61*993b0882SAndroid Build Coastguard Worker int CenterTokenFromMiddleOfSelection(
62*993b0882SAndroid Build Coastguard Worker     const CodepointSpan& span, const std::vector<Token>& selectable_tokens);
63*993b0882SAndroid Build Coastguard Worker 
64*993b0882SAndroid Build Coastguard Worker // Strips the tokens from the tokens vector that are not used for feature
65*993b0882SAndroid Build Coastguard Worker // extraction because they are out of scope, or pads them so that there is
66*993b0882SAndroid Build Coastguard Worker // enough tokens in the required context_size for all inferences with a click
67*993b0882SAndroid Build Coastguard Worker // in relative_click_span.
68*993b0882SAndroid Build Coastguard Worker void StripOrPadTokens(const TokenSpan& relative_click_span, int context_size,
69*993b0882SAndroid Build Coastguard Worker                       std::vector<Token>* tokens, int* click_pos);
70*993b0882SAndroid Build Coastguard Worker 
71*993b0882SAndroid Build Coastguard Worker }  // namespace internal
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker // Converts a codepoint span to a token span in the given list of tokens.
74*993b0882SAndroid Build Coastguard Worker // If snap_boundaries_to_containing_tokens is set to true, it is enough for a
75*993b0882SAndroid Build Coastguard Worker // token to overlap with the codepoint range to be considered part of it.
76*993b0882SAndroid Build Coastguard Worker // Otherwise it must be fully included in the range.
77*993b0882SAndroid Build Coastguard Worker TokenSpan CodepointSpanToTokenSpan(
78*993b0882SAndroid Build Coastguard Worker     const std::vector<Token>& selectable_tokens,
79*993b0882SAndroid Build Coastguard Worker     const CodepointSpan& codepoint_span,
80*993b0882SAndroid Build Coastguard Worker     bool snap_boundaries_to_containing_tokens = false);
81*993b0882SAndroid Build Coastguard Worker 
82*993b0882SAndroid Build Coastguard Worker // Converts a token span to a codepoint span in the given list of tokens.
83*993b0882SAndroid Build Coastguard Worker CodepointSpan TokenSpanToCodepointSpan(
84*993b0882SAndroid Build Coastguard Worker     const std::vector<Token>& selectable_tokens, const TokenSpan& token_span);
85*993b0882SAndroid Build Coastguard Worker 
86*993b0882SAndroid Build Coastguard Worker // Converts a codepoint span to a unicode text range, within the given unicode
87*993b0882SAndroid Build Coastguard Worker // text.
88*993b0882SAndroid Build Coastguard Worker // For an invalid span (with a negative index), returns (begin, begin). This
89*993b0882SAndroid Build Coastguard Worker // means that it is safe to call this function before checking the validity of
90*993b0882SAndroid Build Coastguard Worker // the span.
91*993b0882SAndroid Build Coastguard Worker // The indices must fit within the unicode text.
92*993b0882SAndroid Build Coastguard Worker // Note that the execution time is linear with respect to the codepoint indices.
93*993b0882SAndroid Build Coastguard Worker // Calling this function repeatedly for spans on the same text might lead to
94*993b0882SAndroid Build Coastguard Worker // inefficient code.
95*993b0882SAndroid Build Coastguard Worker UnicodeTextRange CodepointSpanToUnicodeTextRange(
96*993b0882SAndroid Build Coastguard Worker     const UnicodeText& unicode_text, const CodepointSpan& span);
97*993b0882SAndroid Build Coastguard Worker 
98*993b0882SAndroid Build Coastguard Worker // Takes care of preparing features for the span prediction model.
99*993b0882SAndroid Build Coastguard Worker class FeatureProcessor {
100*993b0882SAndroid Build Coastguard Worker  public:
101*993b0882SAndroid Build Coastguard Worker   // A cache mapping codepoint spans to embedded tokens features. An instance
102*993b0882SAndroid Build Coastguard Worker   // can be provided to multiple calls to ExtractFeatures() operating on the
103*993b0882SAndroid Build Coastguard Worker   // same context (the same codepoint spans corresponding to the same tokens),
104*993b0882SAndroid Build Coastguard Worker   // as an optimization. Note that the tokenizations do not have to be
105*993b0882SAndroid Build Coastguard Worker   // identical.
106*993b0882SAndroid Build Coastguard Worker   typedef std::map<CodepointSpan, std::vector<float>> EmbeddingCache;
107*993b0882SAndroid Build Coastguard Worker 
FeatureProcessor(const FeatureProcessorOptions * options,const UniLib * unilib)108*993b0882SAndroid Build Coastguard Worker   explicit FeatureProcessor(const FeatureProcessorOptions* options,
109*993b0882SAndroid Build Coastguard Worker                             const UniLib* unilib)
110*993b0882SAndroid Build Coastguard Worker       : feature_extractor_(internal::BuildTokenFeatureExtractorOptions(options),
111*993b0882SAndroid Build Coastguard Worker                            unilib),
112*993b0882SAndroid Build Coastguard Worker         options_(options),
113*993b0882SAndroid Build Coastguard Worker         tokenizer_(internal::BuildTokenizer(options, unilib)) {
114*993b0882SAndroid Build Coastguard Worker     MakeLabelMaps();
115*993b0882SAndroid Build Coastguard Worker     if (options->supported_codepoint_ranges() != nullptr) {
116*993b0882SAndroid Build Coastguard Worker       SortCodepointRanges({options->supported_codepoint_ranges()->begin(),
117*993b0882SAndroid Build Coastguard Worker                            options->supported_codepoint_ranges()->end()},
118*993b0882SAndroid Build Coastguard Worker                           &supported_codepoint_ranges_);
119*993b0882SAndroid Build Coastguard Worker     }
120*993b0882SAndroid Build Coastguard Worker     PrepareIgnoredSpanBoundaryCodepoints();
121*993b0882SAndroid Build Coastguard Worker   }
122*993b0882SAndroid Build Coastguard Worker 
123*993b0882SAndroid Build Coastguard Worker   // Tokenizes the input string using the selected tokenization method.
124*993b0882SAndroid Build Coastguard Worker   std::vector<Token> Tokenize(const std::string& text) const;
125*993b0882SAndroid Build Coastguard Worker 
126*993b0882SAndroid Build Coastguard Worker   // Same as above but takes UnicodeText.
127*993b0882SAndroid Build Coastguard Worker   std::vector<Token> Tokenize(const UnicodeText& text_unicode) const;
128*993b0882SAndroid Build Coastguard Worker 
129*993b0882SAndroid Build Coastguard Worker   // Converts a label into a token span.
130*993b0882SAndroid Build Coastguard Worker   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
131*993b0882SAndroid Build Coastguard Worker 
132*993b0882SAndroid Build Coastguard Worker   // Gets the total number of selection labels.
GetSelectionLabelCount()133*993b0882SAndroid Build Coastguard Worker   int GetSelectionLabelCount() const { return label_to_selection_.size(); }
134*993b0882SAndroid Build Coastguard Worker 
135*993b0882SAndroid Build Coastguard Worker   // Gets the string value for given collection label.
136*993b0882SAndroid Build Coastguard Worker   std::string LabelToCollection(int label) const;
137*993b0882SAndroid Build Coastguard Worker 
138*993b0882SAndroid Build Coastguard Worker   // Gets the total number of collections of the model.
NumCollections()139*993b0882SAndroid Build Coastguard Worker   int NumCollections() const { return collection_to_label_.size(); }
140*993b0882SAndroid Build Coastguard Worker 
141*993b0882SAndroid Build Coastguard Worker   // Gets the name of the default collection.
142*993b0882SAndroid Build Coastguard Worker   std::string GetDefaultCollection() const;
143*993b0882SAndroid Build Coastguard Worker 
GetOptions()144*993b0882SAndroid Build Coastguard Worker   const FeatureProcessorOptions* GetOptions() const { return options_; }
145*993b0882SAndroid Build Coastguard Worker 
146*993b0882SAndroid Build Coastguard Worker   // Retokenizes the context and input span, and finds the click position.
147*993b0882SAndroid Build Coastguard Worker   // Depending on the options, might modify tokens (split them or remove them).
148*993b0882SAndroid Build Coastguard Worker   void RetokenizeAndFindClick(const std::string& context,
149*993b0882SAndroid Build Coastguard Worker                               const CodepointSpan& input_span,
150*993b0882SAndroid Build Coastguard Worker                               bool only_use_line_with_click,
151*993b0882SAndroid Build Coastguard Worker                               std::vector<Token>* tokens, int* click_pos) const;
152*993b0882SAndroid Build Coastguard Worker 
153*993b0882SAndroid Build Coastguard Worker   // Same as above, but takes UnicodeText and iterators within it corresponding
154*993b0882SAndroid Build Coastguard Worker   // to input_span.
155*993b0882SAndroid Build Coastguard Worker   void RetokenizeAndFindClick(const UnicodeText& context_unicode,
156*993b0882SAndroid Build Coastguard Worker                               const UnicodeText::const_iterator& span_begin,
157*993b0882SAndroid Build Coastguard Worker                               const UnicodeText::const_iterator& span_end,
158*993b0882SAndroid Build Coastguard Worker                               const CodepointSpan& input_span,
159*993b0882SAndroid Build Coastguard Worker                               bool only_use_line_with_click,
160*993b0882SAndroid Build Coastguard Worker                               std::vector<Token>* tokens, int* click_pos) const;
161*993b0882SAndroid Build Coastguard Worker 
162*993b0882SAndroid Build Coastguard Worker   // Returns true if the token span has enough supported codepoints (as defined
163*993b0882SAndroid Build Coastguard Worker   // in the model config) or not and model should not run.
164*993b0882SAndroid Build Coastguard Worker   bool HasEnoughSupportedCodepoints(const std::vector<Token>& tokens,
165*993b0882SAndroid Build Coastguard Worker                                     const TokenSpan& token_span) const;
166*993b0882SAndroid Build Coastguard Worker 
167*993b0882SAndroid Build Coastguard Worker   // Extracts features as a CachedFeatures object that can be used for repeated
168*993b0882SAndroid Build Coastguard Worker   // inference over token spans in the given context.
169*993b0882SAndroid Build Coastguard Worker   bool ExtractFeatures(const std::vector<Token>& tokens,
170*993b0882SAndroid Build Coastguard Worker                        const TokenSpan& token_span,
171*993b0882SAndroid Build Coastguard Worker                        const CodepointSpan& selection_span_for_feature,
172*993b0882SAndroid Build Coastguard Worker                        const EmbeddingExecutor* embedding_executor,
173*993b0882SAndroid Build Coastguard Worker                        EmbeddingCache* embedding_cache, int feature_vector_size,
174*993b0882SAndroid Build Coastguard Worker                        std::unique_ptr<CachedFeatures>* cached_features) const;
175*993b0882SAndroid Build Coastguard Worker 
176*993b0882SAndroid Build Coastguard Worker   // Fills selection_label_spans with CodepointSpans that correspond to the
177*993b0882SAndroid Build Coastguard Worker   // selection labels. The CodepointSpans are based on the codepoint ranges of
178*993b0882SAndroid Build Coastguard Worker   // given tokens.
179*993b0882SAndroid Build Coastguard Worker   bool SelectionLabelSpans(
180*993b0882SAndroid Build Coastguard Worker       VectorSpan<Token> tokens,
181*993b0882SAndroid Build Coastguard Worker       std::vector<CodepointSpan>* selection_label_spans) const;
182*993b0882SAndroid Build Coastguard Worker 
183*993b0882SAndroid Build Coastguard Worker   // Fills selection_label_relative_token_spans with number of tokens left and
184*993b0882SAndroid Build Coastguard Worker   // right from the click.
185*993b0882SAndroid Build Coastguard Worker   bool SelectionLabelRelativeTokenSpans(
186*993b0882SAndroid Build Coastguard Worker       std::vector<TokenSpan>* selection_label_relative_token_spans) const;
187*993b0882SAndroid Build Coastguard Worker 
DenseFeaturesCount()188*993b0882SAndroid Build Coastguard Worker   int DenseFeaturesCount() const {
189*993b0882SAndroid Build Coastguard Worker     return feature_extractor_.DenseFeaturesCount();
190*993b0882SAndroid Build Coastguard Worker   }
191*993b0882SAndroid Build Coastguard Worker 
EmbeddingSize()192*993b0882SAndroid Build Coastguard Worker   int EmbeddingSize() const { return options_->embedding_size(); }
193*993b0882SAndroid Build Coastguard Worker 
194*993b0882SAndroid Build Coastguard Worker   // Splits context to several segments.
195*993b0882SAndroid Build Coastguard Worker   std::vector<UnicodeTextRange> SplitContext(
196*993b0882SAndroid Build Coastguard Worker       const UnicodeText& context_unicode,
197*993b0882SAndroid Build Coastguard Worker       const bool use_pipe_character_for_newline) const;
198*993b0882SAndroid Build Coastguard Worker 
199*993b0882SAndroid Build Coastguard Worker   // Strips boundary codepoints from the span in context and returns the new
200*993b0882SAndroid Build Coastguard Worker   // start and end indices. If the span comprises entirely of boundary
201*993b0882SAndroid Build Coastguard Worker   // codepoints, the first index of span is returned for both indices.
202*993b0882SAndroid Build Coastguard Worker   CodepointSpan StripBoundaryCodepoints(const std::string& context,
203*993b0882SAndroid Build Coastguard Worker                                         const CodepointSpan& span) const;
204*993b0882SAndroid Build Coastguard Worker 
205*993b0882SAndroid Build Coastguard Worker   // Same as above but takes UnicodeText.
206*993b0882SAndroid Build Coastguard Worker   CodepointSpan StripBoundaryCodepoints(const UnicodeText& context_unicode,
207*993b0882SAndroid Build Coastguard Worker                                         const CodepointSpan& span) const;
208*993b0882SAndroid Build Coastguard Worker 
209*993b0882SAndroid Build Coastguard Worker   // Same as above but takes a pair of iterators for the span, for efficiency.
210*993b0882SAndroid Build Coastguard Worker   CodepointSpan StripBoundaryCodepoints(
211*993b0882SAndroid Build Coastguard Worker       const UnicodeText::const_iterator& span_begin,
212*993b0882SAndroid Build Coastguard Worker       const UnicodeText::const_iterator& span_end,
213*993b0882SAndroid Build Coastguard Worker       const CodepointSpan& span) const;
214*993b0882SAndroid Build Coastguard Worker 
215*993b0882SAndroid Build Coastguard Worker   // Same as above, but takes an optional buffer for saving the modified value.
216*993b0882SAndroid Build Coastguard Worker   // As an optimization, returns pointer to 'value' if nothing was stripped, or
217*993b0882SAndroid Build Coastguard Worker   // pointer to 'buffer' if something was stripped.
218*993b0882SAndroid Build Coastguard Worker   const std::string& StripBoundaryCodepoints(const std::string& value,
219*993b0882SAndroid Build Coastguard Worker                                              std::string* buffer) const;
220*993b0882SAndroid Build Coastguard Worker 
221*993b0882SAndroid Build Coastguard Worker  protected:
222*993b0882SAndroid Build Coastguard Worker   // Returns the class id corresponding to the given string collection
223*993b0882SAndroid Build Coastguard Worker   // identifier. There is a catch-all class id that the function returns for
224*993b0882SAndroid Build Coastguard Worker   // unknown collections.
225*993b0882SAndroid Build Coastguard Worker   int CollectionToLabel(const std::string& collection) const;
226*993b0882SAndroid Build Coastguard Worker 
227*993b0882SAndroid Build Coastguard Worker   // Prepares mapping from collection names to labels.
228*993b0882SAndroid Build Coastguard Worker   void MakeLabelMaps();
229*993b0882SAndroid Build Coastguard Worker 
230*993b0882SAndroid Build Coastguard Worker   // Gets the number of spannable tokens for the model.
231*993b0882SAndroid Build Coastguard Worker   //
232*993b0882SAndroid Build Coastguard Worker   // Spannable tokens are those tokens of context, which the model predicts
233*993b0882SAndroid Build Coastguard Worker   // selection spans over (i.e., there is 1:1 correspondence between the output
234*993b0882SAndroid Build Coastguard Worker   // classes of the model and each of the spannable tokens).
GetNumContextTokens()235*993b0882SAndroid Build Coastguard Worker   int GetNumContextTokens() const { return options_->context_size() * 2 + 1; }
236*993b0882SAndroid Build Coastguard Worker 
237*993b0882SAndroid Build Coastguard Worker   // Converts a label into a span of codepoint indices corresponding to it
238*993b0882SAndroid Build Coastguard Worker   // given output_tokens.
239*993b0882SAndroid Build Coastguard Worker   bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
240*993b0882SAndroid Build Coastguard Worker                    CodepointSpan* span) const;
241*993b0882SAndroid Build Coastguard Worker 
242*993b0882SAndroid Build Coastguard Worker   // Converts a span to the corresponding label given output_tokens.
243*993b0882SAndroid Build Coastguard Worker   bool SpanToLabel(const CodepointSpan& span,
244*993b0882SAndroid Build Coastguard Worker                    const std::vector<Token>& output_tokens, int* label) const;
245*993b0882SAndroid Build Coastguard Worker 
246*993b0882SAndroid Build Coastguard Worker   // Converts a token span to the corresponding label.
247*993b0882SAndroid Build Coastguard Worker   int TokenSpanToLabel(const TokenSpan& token_span) const;
248*993b0882SAndroid Build Coastguard Worker 
249*993b0882SAndroid Build Coastguard Worker   // Returns the ratio of supported codepoints to total number of codepoints in
250*993b0882SAndroid Build Coastguard Worker   // the given token span.
251*993b0882SAndroid Build Coastguard Worker   float SupportedCodepointsRatio(const TokenSpan& token_span,
252*993b0882SAndroid Build Coastguard Worker                                  const std::vector<Token>& tokens) const;
253*993b0882SAndroid Build Coastguard Worker 
254*993b0882SAndroid Build Coastguard Worker   void PrepareIgnoredSpanBoundaryCodepoints();
255*993b0882SAndroid Build Coastguard Worker 
256*993b0882SAndroid Build Coastguard Worker   // Counts the number of span boundary codepoints. If count_from_beginning is
257*993b0882SAndroid Build Coastguard Worker   // True, the counting will start at the span_start iterator (inclusive) and at
258*993b0882SAndroid Build Coastguard Worker   // maximum end at span_end (exclusive). If count_from_beginning is True, the
259*993b0882SAndroid Build Coastguard Worker   // counting will start from span_end (exclusive) and end at span_start
260*993b0882SAndroid Build Coastguard Worker   // (inclusive).
261*993b0882SAndroid Build Coastguard Worker   int CountIgnoredSpanBoundaryCodepoints(
262*993b0882SAndroid Build Coastguard Worker       const UnicodeText::const_iterator& span_start,
263*993b0882SAndroid Build Coastguard Worker       const UnicodeText::const_iterator& span_end,
264*993b0882SAndroid Build Coastguard Worker       bool count_from_beginning) const;
265*993b0882SAndroid Build Coastguard Worker 
266*993b0882SAndroid Build Coastguard Worker   // Finds the center token index in tokens vector, using the method defined
267*993b0882SAndroid Build Coastguard Worker   // in options_.
268*993b0882SAndroid Build Coastguard Worker   int FindCenterToken(const CodepointSpan& span,
269*993b0882SAndroid Build Coastguard Worker                       const std::vector<Token>& tokens) const;
270*993b0882SAndroid Build Coastguard Worker 
271*993b0882SAndroid Build Coastguard Worker   // Removes all tokens from tokens that are not on a line (defined by calling
272*993b0882SAndroid Build Coastguard Worker   // SplitContext on the context) to which span points.
273*993b0882SAndroid Build Coastguard Worker   void StripTokensFromOtherLines(const std::string& context,
274*993b0882SAndroid Build Coastguard Worker                                  const CodepointSpan& span,
275*993b0882SAndroid Build Coastguard Worker                                  std::vector<Token>* tokens) const;
276*993b0882SAndroid Build Coastguard Worker 
277*993b0882SAndroid Build Coastguard Worker   // Same as above but takes UnicodeText.
278*993b0882SAndroid Build Coastguard Worker   void StripTokensFromOtherLines(const UnicodeText& context_unicode,
279*993b0882SAndroid Build Coastguard Worker                                  const UnicodeText::const_iterator& span_begin,
280*993b0882SAndroid Build Coastguard Worker                                  const UnicodeText::const_iterator& span_end,
281*993b0882SAndroid Build Coastguard Worker                                  const CodepointSpan& span,
282*993b0882SAndroid Build Coastguard Worker                                  std::vector<Token>* tokens) const;
283*993b0882SAndroid Build Coastguard Worker 
284*993b0882SAndroid Build Coastguard Worker   // Extracts the features of a token and appends them to the output vector.
285*993b0882SAndroid Build Coastguard Worker   // Uses the embedding cache to to avoid re-extracting the re-embedding the
286*993b0882SAndroid Build Coastguard Worker   // sparse features for the same token.
287*993b0882SAndroid Build Coastguard Worker   bool AppendTokenFeaturesWithCache(
288*993b0882SAndroid Build Coastguard Worker       const Token& token, const CodepointSpan& selection_span_for_feature,
289*993b0882SAndroid Build Coastguard Worker       const EmbeddingExecutor* embedding_executor,
290*993b0882SAndroid Build Coastguard Worker       EmbeddingCache* embedding_cache,
291*993b0882SAndroid Build Coastguard Worker       std::vector<float>* output_features) const;
292*993b0882SAndroid Build Coastguard Worker 
293*993b0882SAndroid Build Coastguard Worker  protected:
294*993b0882SAndroid Build Coastguard Worker   const TokenFeatureExtractor feature_extractor_;
295*993b0882SAndroid Build Coastguard Worker 
296*993b0882SAndroid Build Coastguard Worker   // Codepoint ranges that define what codepoints are supported by the model.
297*993b0882SAndroid Build Coastguard Worker   // NOTE: Must be sorted.
298*993b0882SAndroid Build Coastguard Worker   std::vector<CodepointRangeStruct> supported_codepoint_ranges_;
299*993b0882SAndroid Build Coastguard Worker 
300*993b0882SAndroid Build Coastguard Worker  private:
301*993b0882SAndroid Build Coastguard Worker   // Set of codepoints that will be stripped from beginning and end of
302*993b0882SAndroid Build Coastguard Worker   // predicted spans.
303*993b0882SAndroid Build Coastguard Worker   std::unordered_set<int32> ignored_span_boundary_codepoints_;
304*993b0882SAndroid Build Coastguard Worker 
305*993b0882SAndroid Build Coastguard Worker   const FeatureProcessorOptions* const options_;
306*993b0882SAndroid Build Coastguard Worker 
307*993b0882SAndroid Build Coastguard Worker   // Mapping between token selection spans and labels ids.
308*993b0882SAndroid Build Coastguard Worker   std::map<TokenSpan, int> selection_to_label_;
309*993b0882SAndroid Build Coastguard Worker   std::vector<TokenSpan> label_to_selection_;
310*993b0882SAndroid Build Coastguard Worker 
311*993b0882SAndroid Build Coastguard Worker   // Mapping between collections and labels.
312*993b0882SAndroid Build Coastguard Worker   std::map<std::string, int> collection_to_label_;
313*993b0882SAndroid Build Coastguard Worker 
314*993b0882SAndroid Build Coastguard Worker   Tokenizer tokenizer_;
315*993b0882SAndroid Build Coastguard Worker };
316*993b0882SAndroid Build Coastguard Worker 
317*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
318*993b0882SAndroid Build Coastguard Worker 
319*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_FEATURE_PROCESSOR_H_
320