xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/skipgram_finder.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/skipgram_finder.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <cctype>
20*993b0882SAndroid Build Coastguard Worker #include <deque>
21*993b0882SAndroid Build Coastguard Worker #include <string>
22*993b0882SAndroid Build Coastguard Worker #include <vector>
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker #include "utils/strings/utf8.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unilib-common.h"
26*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h"
27*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h"
28*993b0882SAndroid Build Coastguard Worker #include "absl/strings/match.h"
29*993b0882SAndroid Build Coastguard Worker #include "absl/strings/str_split.h"
30*993b0882SAndroid Build Coastguard Worker #include "absl/strings/string_view.h"
31*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h"
32*993b0882SAndroid Build Coastguard Worker 
33*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
34*993b0882SAndroid Build Coastguard Worker namespace {
35*993b0882SAndroid Build Coastguard Worker 
36*993b0882SAndroid Build Coastguard Worker using ::tflite::StringRef;
37*993b0882SAndroid Build Coastguard Worker 
PreprocessToken(std::string & token)38*993b0882SAndroid Build Coastguard Worker void PreprocessToken(std::string& token) {
39*993b0882SAndroid Build Coastguard Worker   size_t in = 0;
40*993b0882SAndroid Build Coastguard Worker   size_t out = 0;
41*993b0882SAndroid Build Coastguard Worker   while (in < token.size()) {
42*993b0882SAndroid Build Coastguard Worker     const char* in_data = token.data() + in;
43*993b0882SAndroid Build Coastguard Worker     const int n = GetNumBytesForUTF8Char(in_data);
44*993b0882SAndroid Build Coastguard Worker     if (n < 0 || n > token.size() - in) {
45*993b0882SAndroid Build Coastguard Worker       // Invalid Utf8 sequence.
46*993b0882SAndroid Build Coastguard Worker       break;
47*993b0882SAndroid Build Coastguard Worker     }
48*993b0882SAndroid Build Coastguard Worker     in += n;
49*993b0882SAndroid Build Coastguard Worker     const char32 r = ValidCharToRune(in_data);
50*993b0882SAndroid Build Coastguard Worker     if (IsPunctuation(r)) {
51*993b0882SAndroid Build Coastguard Worker       continue;
52*993b0882SAndroid Build Coastguard Worker     }
53*993b0882SAndroid Build Coastguard Worker     const char32 rl = ToLower(r);
54*993b0882SAndroid Build Coastguard Worker     char output_buffer[4];
55*993b0882SAndroid Build Coastguard Worker     int encoded_length = ValidRuneToChar(rl, output_buffer);
56*993b0882SAndroid Build Coastguard Worker     if (encoded_length > n) {
57*993b0882SAndroid Build Coastguard Worker       // This is a hack, but there are exactly two unicode characters whose
58*993b0882SAndroid Build Coastguard Worker       // lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
59*993b0882SAndroid Build Coastguard Worker       // 0x23e to 0x2c66).  So, to avoid sizing issues, they're not lowercased.
60*993b0882SAndroid Build Coastguard Worker       encoded_length = ValidRuneToChar(r, output_buffer);
61*993b0882SAndroid Build Coastguard Worker     }
62*993b0882SAndroid Build Coastguard Worker     memcpy(token.data() + out, output_buffer, encoded_length);
63*993b0882SAndroid Build Coastguard Worker     out += encoded_length;
64*993b0882SAndroid Build Coastguard Worker   }
65*993b0882SAndroid Build Coastguard Worker 
66*993b0882SAndroid Build Coastguard Worker   size_t remaining = token.size() - in;
67*993b0882SAndroid Build Coastguard Worker   if (remaining > 0) {
68*993b0882SAndroid Build Coastguard Worker     memmove(token.data() + out, token.data() + in, remaining);
69*993b0882SAndroid Build Coastguard Worker     out += remaining;
70*993b0882SAndroid Build Coastguard Worker   }
71*993b0882SAndroid Build Coastguard Worker   token.resize(out);
72*993b0882SAndroid Build Coastguard Worker }
73*993b0882SAndroid Build Coastguard Worker 
74*993b0882SAndroid Build Coastguard Worker }  // namespace
75*993b0882SAndroid Build Coastguard Worker 
AddSkipgram(const std::string & skipgram,int category)76*993b0882SAndroid Build Coastguard Worker void SkipgramFinder::AddSkipgram(const std::string& skipgram, int category) {
77*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> tokens = absl::StrSplit(skipgram, ' ');
78*993b0882SAndroid Build Coastguard Worker 
79*993b0882SAndroid Build Coastguard Worker   // Store the skipgram in a trie-like structure that uses tokens as the
80*993b0882SAndroid Build Coastguard Worker   // edge labels, instead of characters.  Each node represents a skipgram made
81*993b0882SAndroid Build Coastguard Worker   // from the tokens used to reach the node, and stores the categories the
82*993b0882SAndroid Build Coastguard Worker   // skipgram is associated with.
83*993b0882SAndroid Build Coastguard Worker   TrieNode* cur = &skipgram_trie_;
84*993b0882SAndroid Build Coastguard Worker   for (auto& token : tokens) {
85*993b0882SAndroid Build Coastguard Worker     if (absl::EndsWith(token, ".*")) {
86*993b0882SAndroid Build Coastguard Worker       token.resize(token.size() - 2);
87*993b0882SAndroid Build Coastguard Worker       PreprocessToken(token);
88*993b0882SAndroid Build Coastguard Worker       auto iter = cur->prefix_to_node.find(token);
89*993b0882SAndroid Build Coastguard Worker       if (iter != cur->prefix_to_node.end()) {
90*993b0882SAndroid Build Coastguard Worker         cur = &iter->second;
91*993b0882SAndroid Build Coastguard Worker       } else {
92*993b0882SAndroid Build Coastguard Worker         cur = &cur->prefix_to_node
93*993b0882SAndroid Build Coastguard Worker                    .emplace(std::piecewise_construct,
94*993b0882SAndroid Build Coastguard Worker                             std::forward_as_tuple(token), std::make_tuple<>())
95*993b0882SAndroid Build Coastguard Worker                    .first->second;
96*993b0882SAndroid Build Coastguard Worker       }
97*993b0882SAndroid Build Coastguard Worker       continue;
98*993b0882SAndroid Build Coastguard Worker     }
99*993b0882SAndroid Build Coastguard Worker 
100*993b0882SAndroid Build Coastguard Worker     PreprocessToken(token);
101*993b0882SAndroid Build Coastguard Worker     auto iter = cur->token_to_node.find(token);
102*993b0882SAndroid Build Coastguard Worker     if (iter != cur->token_to_node.end()) {
103*993b0882SAndroid Build Coastguard Worker       cur = &iter->second;
104*993b0882SAndroid Build Coastguard Worker     } else {
105*993b0882SAndroid Build Coastguard Worker       cur = &cur->token_to_node
106*993b0882SAndroid Build Coastguard Worker                  .emplace(std::piecewise_construct,
107*993b0882SAndroid Build Coastguard Worker                           std::forward_as_tuple(token), std::make_tuple<>())
108*993b0882SAndroid Build Coastguard Worker                  .first->second;
109*993b0882SAndroid Build Coastguard Worker     }
110*993b0882SAndroid Build Coastguard Worker   }
111*993b0882SAndroid Build Coastguard Worker   cur->categories.insert(category);
112*993b0882SAndroid Build Coastguard Worker }
113*993b0882SAndroid Build Coastguard Worker 
FindSkipgrams(const std::string & input) const114*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
115*993b0882SAndroid Build Coastguard Worker     const std::string& input) const {
116*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> tokens = absl::StrSplit(input, ' ');
117*993b0882SAndroid Build Coastguard Worker   std::vector<absl::string_view> sv_tokens;
118*993b0882SAndroid Build Coastguard Worker   sv_tokens.reserve(tokens.size());
119*993b0882SAndroid Build Coastguard Worker   for (auto& token : tokens) {
120*993b0882SAndroid Build Coastguard Worker     PreprocessToken(token);
121*993b0882SAndroid Build Coastguard Worker     sv_tokens.emplace_back(token.data(), token.size());
122*993b0882SAndroid Build Coastguard Worker   }
123*993b0882SAndroid Build Coastguard Worker   return FindSkipgrams(sv_tokens);
124*993b0882SAndroid Build Coastguard Worker }
125*993b0882SAndroid Build Coastguard Worker 
FindSkipgrams(const std::vector<StringRef> & tokens) const126*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
127*993b0882SAndroid Build Coastguard Worker     const std::vector<StringRef>& tokens) const {
128*993b0882SAndroid Build Coastguard Worker   std::vector<absl::string_view> sv_tokens;
129*993b0882SAndroid Build Coastguard Worker   sv_tokens.reserve(tokens.size());
130*993b0882SAndroid Build Coastguard Worker   for (auto& token : tokens) {
131*993b0882SAndroid Build Coastguard Worker     sv_tokens.emplace_back(token.str, token.len);
132*993b0882SAndroid Build Coastguard Worker   }
133*993b0882SAndroid Build Coastguard Worker   return FindSkipgrams(sv_tokens);
134*993b0882SAndroid Build Coastguard Worker }
135*993b0882SAndroid Build Coastguard Worker 
FindSkipgrams(const std::vector<absl::string_view> & tokens) const136*993b0882SAndroid Build Coastguard Worker absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
137*993b0882SAndroid Build Coastguard Worker     const std::vector<absl::string_view>& tokens) const {
138*993b0882SAndroid Build Coastguard Worker   absl::flat_hash_set<int> categories;
139*993b0882SAndroid Build Coastguard Worker 
140*993b0882SAndroid Build Coastguard Worker   // Tracks skipgram prefixes and the index of their last token.
141*993b0882SAndroid Build Coastguard Worker   std::deque<std::pair<int, const TrieNode*>> indices_and_skipgrams;
142*993b0882SAndroid Build Coastguard Worker 
143*993b0882SAndroid Build Coastguard Worker   for (int token_i = 0; token_i < tokens.size(); token_i++) {
144*993b0882SAndroid Build Coastguard Worker     const absl::string_view& token = tokens[token_i];
145*993b0882SAndroid Build Coastguard Worker 
146*993b0882SAndroid Build Coastguard Worker     std::vector<absl::string_view> token_prefixes;
147*993b0882SAndroid Build Coastguard Worker     {
148*993b0882SAndroid Build Coastguard Worker       const char* s = token.data();
149*993b0882SAndroid Build Coastguard Worker       int n = token.size();
150*993b0882SAndroid Build Coastguard Worker       while (n > 0) {
151*993b0882SAndroid Build Coastguard Worker         const int rlen = GetNumBytesForUTF8Char(s);
152*993b0882SAndroid Build Coastguard Worker         if (rlen < 0 || rlen > n) {
153*993b0882SAndroid Build Coastguard Worker           // Invalid UTF8.
154*993b0882SAndroid Build Coastguard Worker           break;
155*993b0882SAndroid Build Coastguard Worker         }
156*993b0882SAndroid Build Coastguard Worker         n -= rlen;
157*993b0882SAndroid Build Coastguard Worker         s += rlen;
158*993b0882SAndroid Build Coastguard Worker         token_prefixes.emplace_back(token.data(), token.size() - n);
159*993b0882SAndroid Build Coastguard Worker       }
160*993b0882SAndroid Build Coastguard Worker     }
161*993b0882SAndroid Build Coastguard Worker 
162*993b0882SAndroid Build Coastguard Worker     // Drop any skipgrams prefixes which would skip more than `max_skip_size_`
163*993b0882SAndroid Build Coastguard Worker     // tokens between the end of the prefix and the current token.
164*993b0882SAndroid Build Coastguard Worker     while (!indices_and_skipgrams.empty()) {
165*993b0882SAndroid Build Coastguard Worker       if (indices_and_skipgrams.front().first + max_skip_size_ + 1 < token_i) {
166*993b0882SAndroid Build Coastguard Worker         indices_and_skipgrams.pop_front();
167*993b0882SAndroid Build Coastguard Worker       } else {
168*993b0882SAndroid Build Coastguard Worker         break;
169*993b0882SAndroid Build Coastguard Worker       }
170*993b0882SAndroid Build Coastguard Worker     }
171*993b0882SAndroid Build Coastguard Worker 
172*993b0882SAndroid Build Coastguard Worker     // Check if we can form a valid skipgram prefix (or skipgram) by adding
173*993b0882SAndroid Build Coastguard Worker     // the current token to any of the existing skipgram prefixes, or
174*993b0882SAndroid Build Coastguard Worker     // if the current token is a valid skipgram prefix (or skipgram).
175*993b0882SAndroid Build Coastguard Worker     size_t size = indices_and_skipgrams.size();
176*993b0882SAndroid Build Coastguard Worker     for (size_t skipgram_i = 0; skipgram_i <= size; skipgram_i++) {
177*993b0882SAndroid Build Coastguard Worker       const auto& node = skipgram_i < size
178*993b0882SAndroid Build Coastguard Worker                              ? *indices_and_skipgrams[skipgram_i].second
179*993b0882SAndroid Build Coastguard Worker                              : skipgram_trie_;
180*993b0882SAndroid Build Coastguard Worker 
181*993b0882SAndroid Build Coastguard Worker       auto iter = node.token_to_node.find(token);
182*993b0882SAndroid Build Coastguard Worker       if (iter != node.token_to_node.end()) {
183*993b0882SAndroid Build Coastguard Worker         categories.insert(iter->second.categories.begin(),
184*993b0882SAndroid Build Coastguard Worker                           iter->second.categories.end());
185*993b0882SAndroid Build Coastguard Worker         indices_and_skipgrams.push_back(std::make_pair(token_i, &iter->second));
186*993b0882SAndroid Build Coastguard Worker       }
187*993b0882SAndroid Build Coastguard Worker 
188*993b0882SAndroid Build Coastguard Worker       for (const auto& token_prefix : token_prefixes) {
189*993b0882SAndroid Build Coastguard Worker         auto iter = node.prefix_to_node.find(token_prefix);
190*993b0882SAndroid Build Coastguard Worker         if (iter != node.prefix_to_node.end()) {
191*993b0882SAndroid Build Coastguard Worker           categories.insert(iter->second.categories.begin(),
192*993b0882SAndroid Build Coastguard Worker                             iter->second.categories.end());
193*993b0882SAndroid Build Coastguard Worker           indices_and_skipgrams.push_back(
194*993b0882SAndroid Build Coastguard Worker               std::make_pair(token_i, &iter->second));
195*993b0882SAndroid Build Coastguard Worker         }
196*993b0882SAndroid Build Coastguard Worker       }
197*993b0882SAndroid Build Coastguard Worker     }
198*993b0882SAndroid Build Coastguard Worker   }
199*993b0882SAndroid Build Coastguard Worker 
200*993b0882SAndroid Build Coastguard Worker   return categories;
201*993b0882SAndroid Build Coastguard Worker }
202*993b0882SAndroid Build Coastguard Worker 
203*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
204