xref: /aosp_15_r20/external/libtextclassifier/native/utils/token-feature-extractor.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
19*993b0882SAndroid Build Coastguard Worker 
20*993b0882SAndroid Build Coastguard Worker #include <memory>
21*993b0882SAndroid Build Coastguard Worker #include <unordered_set>
22*993b0882SAndroid Build Coastguard Worker #include <vector>
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/strings/stringpiece.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unilib.h"
27*993b0882SAndroid Build Coastguard Worker 
28*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
29*993b0882SAndroid Build Coastguard Worker 
// Configuration for TokenFeatureExtractor. Every field carries an in-class
// default, so a default-constructed instance is a valid (minimal) config.
struct TokenFeatureExtractorOptions {
  // How many hash buckets the charactergram features are hashed into.
  int num_buckets{0};

  // Which charactergram lengths to extract; for example 2 selects character
  // bigrams and 3 selects character trigrams.
  std::vector<int> chargram_orders;

  // When set, the token case feature is extracted.
  bool extract_case_feature{false};

  // When set, the unicode-aware code paths are used for feature extraction.
  bool unicode_aware_features{false};

  // When set, the selection mask feature is extracted.
  bool extract_selection_mask_feature{false};

  // Regular expressions whose matches are extracted as features.
  std::vector<std::string> regexp_features;

  // When set, digits are remapped to a single number.
  bool remap_digits{false};

  // When set, tokens are lowercased.
  bool lowercase_tokens{false};

  // Upper bound on the length of a word.
  int max_word_length{20};

  // Charactergram allow-list. Extracted charactergrams absent from this set
  // are treated as out-of-vocabulary. An empty set means no filtering: every
  // charactergram is allowed.
  std::unordered_set<std::string> allowed_chargrams;
};
65*993b0882SAndroid Build Coastguard Worker 
66*993b0882SAndroid Build Coastguard Worker class TokenFeatureExtractor {
67*993b0882SAndroid Build Coastguard Worker  public:
68*993b0882SAndroid Build Coastguard Worker   // Des not take ownership of unilib, which must refer to a valid unilib
69*993b0882SAndroid Build Coastguard Worker   // instance that outlives this feature extractor.
70*993b0882SAndroid Build Coastguard Worker   explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options,
71*993b0882SAndroid Build Coastguard Worker                                  const UniLib* unilib);
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker   // Extracts both the sparse (charactergram) and the dense features from a
74*993b0882SAndroid Build Coastguard Worker   // token. is_in_span is a bool indicator whether the token is a part of the
75*993b0882SAndroid Build Coastguard Worker   // selection span (true) or not (false).
76*993b0882SAndroid Build Coastguard Worker   // The sparse_features output is optional. Fails and returns false if
77*993b0882SAndroid Build Coastguard Worker   // dense_fatures in a nullptr.
78*993b0882SAndroid Build Coastguard Worker   bool Extract(const Token& token, bool is_in_span,
79*993b0882SAndroid Build Coastguard Worker                std::vector<int>* sparse_features,
80*993b0882SAndroid Build Coastguard Worker                std::vector<float>* dense_features) const;
81*993b0882SAndroid Build Coastguard Worker 
82*993b0882SAndroid Build Coastguard Worker   // Extracts the sparse (charactergram) features from the token.
83*993b0882SAndroid Build Coastguard Worker   std::vector<int> ExtractCharactergramFeatures(const Token& token) const;
84*993b0882SAndroid Build Coastguard Worker 
85*993b0882SAndroid Build Coastguard Worker   // Extracts the dense features from the token. is_in_span is a bool indicator
86*993b0882SAndroid Build Coastguard Worker   // whether the token is a part of the selection span (true) or not (false).
87*993b0882SAndroid Build Coastguard Worker   std::vector<float> ExtractDenseFeatures(const Token& token,
88*993b0882SAndroid Build Coastguard Worker                                           bool is_in_span) const;
89*993b0882SAndroid Build Coastguard Worker 
DenseFeaturesCount()90*993b0882SAndroid Build Coastguard Worker   int DenseFeaturesCount() const {
91*993b0882SAndroid Build Coastguard Worker     int feature_count =
92*993b0882SAndroid Build Coastguard Worker         options_.extract_case_feature + options_.extract_selection_mask_feature;
93*993b0882SAndroid Build Coastguard Worker     feature_count += regex_patterns_.size();
94*993b0882SAndroid Build Coastguard Worker     return feature_count;
95*993b0882SAndroid Build Coastguard Worker   }
96*993b0882SAndroid Build Coastguard Worker 
97*993b0882SAndroid Build Coastguard Worker  protected:
98*993b0882SAndroid Build Coastguard Worker   // Hashes given token to given number of buckets.
99*993b0882SAndroid Build Coastguard Worker   int HashToken(StringPiece token) const;
100*993b0882SAndroid Build Coastguard Worker 
101*993b0882SAndroid Build Coastguard Worker   // Extracts the charactergram features from the token in a non-unicode-aware
102*993b0882SAndroid Build Coastguard Worker   // way.
103*993b0882SAndroid Build Coastguard Worker   std::vector<int> ExtractCharactergramFeaturesAscii(const Token& token) const;
104*993b0882SAndroid Build Coastguard Worker 
105*993b0882SAndroid Build Coastguard Worker   // Extracts the charactergram features from the token in a unicode-aware way.
106*993b0882SAndroid Build Coastguard Worker   std::vector<int> ExtractCharactergramFeaturesUnicode(
107*993b0882SAndroid Build Coastguard Worker       const Token& token) const;
108*993b0882SAndroid Build Coastguard Worker 
109*993b0882SAndroid Build Coastguard Worker  private:
110*993b0882SAndroid Build Coastguard Worker   TokenFeatureExtractorOptions options_;
111*993b0882SAndroid Build Coastguard Worker   std::vector<std::unique_ptr<UniLib::RegexPattern>> regex_patterns_;
112*993b0882SAndroid Build Coastguard Worker   const UniLib& unilib_;
113*993b0882SAndroid Build Coastguard Worker };
114*993b0882SAndroid Build Coastguard Worker 
115*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
116*993b0882SAndroid Build Coastguard Worker 
117*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_UTILS_TOKEN_FEATURE_EXTRACTOR_H_
118