xref: /aosp_15_r20/external/libtextclassifier/native/utils/wordpiece_tokenizer.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
#include "utils/wordpiece_tokenizer.h"

#include <iterator>
#include <string>
#include <vector>

#include "utils/utf8/unicodetext.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker 
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker 
Lookup(int byte_start,int byte_end,const absl::string_view token,const std::string & suffix_indicator,const WordpieceVocab * vocab_map,bool * in_vocab)28*993b0882SAndroid Build Coastguard Worker LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
29*993b0882SAndroid Build Coastguard Worker                     const std::string& suffix_indicator,
30*993b0882SAndroid Build Coastguard Worker                     const WordpieceVocab* vocab_map, bool* in_vocab) {
31*993b0882SAndroid Build Coastguard Worker   int byte_len = byte_end - byte_start;
32*993b0882SAndroid Build Coastguard Worker   absl::string_view substr(token.data() + byte_start, byte_len);
33*993b0882SAndroid Build Coastguard Worker   std::string lookup_value;
34*993b0882SAndroid Build Coastguard Worker   if (byte_start > 0) {
35*993b0882SAndroid Build Coastguard Worker     lookup_value = absl::StrCat(suffix_indicator, substr);
36*993b0882SAndroid Build Coastguard Worker   } else {
37*993b0882SAndroid Build Coastguard Worker     // absl::CopyToString
38*993b0882SAndroid Build Coastguard Worker     lookup_value.assign(substr.begin(), substr.end());
39*993b0882SAndroid Build Coastguard Worker   }
40*993b0882SAndroid Build Coastguard Worker   return vocab_map->Contains(lookup_value, in_vocab);
41*993b0882SAndroid Build Coastguard Worker }
42*993b0882SAndroid Build Coastguard Worker 
43*993b0882SAndroid Build Coastguard Worker // Sets byte_end to the longest byte sequence which:
44*993b0882SAndroid Build Coastguard Worker // 1) is a proper UTF8 sequence
45*993b0882SAndroid Build Coastguard Worker // 2) is in the vocab OR if split_unknown_characters is true, is a single
46*993b0882SAndroid Build Coastguard Worker //    UTF8 character.
47*993b0882SAndroid Build Coastguard Worker // If no match is found, found_match is set to false.
LongestMatchStartingAt(int byte_start,const absl::string_view token,const std::string & suffix_indicator,const int max_chars_per_subtoken,bool split_unknown_characters,const WordpieceVocab * vocab_map,int * byte_end,bool * found_match,bool * match_is_unknown_character)48*993b0882SAndroid Build Coastguard Worker LookupStatus LongestMatchStartingAt(
49*993b0882SAndroid Build Coastguard Worker     int byte_start, const absl::string_view token,
50*993b0882SAndroid Build Coastguard Worker     const std::string& suffix_indicator, const int max_chars_per_subtoken,
51*993b0882SAndroid Build Coastguard Worker     bool split_unknown_characters, const WordpieceVocab* vocab_map,
52*993b0882SAndroid Build Coastguard Worker     int* byte_end, bool* found_match, bool* match_is_unknown_character) {
53*993b0882SAndroid Build Coastguard Worker   *match_is_unknown_character = false;
54*993b0882SAndroid Build Coastguard Worker   *found_match = false;
55*993b0882SAndroid Build Coastguard Worker   const UnicodeText unicode_token =
56*993b0882SAndroid Build Coastguard Worker       UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);
57*993b0882SAndroid Build Coastguard Worker   std::vector<int32_t> byte_ends;
58*993b0882SAndroid Build Coastguard Worker   int32_t codepoint_offset = byte_start;
59*993b0882SAndroid Build Coastguard Worker   for (auto it = unicode_token.begin(); it != unicode_token.end(); ++it) {
60*993b0882SAndroid Build Coastguard Worker     codepoint_offset += it.utf8_length();
61*993b0882SAndroid Build Coastguard Worker     byte_ends.push_back(codepoint_offset);
62*993b0882SAndroid Build Coastguard Worker     if (max_chars_per_subtoken > 0 &&
63*993b0882SAndroid Build Coastguard Worker         byte_ends.size() == max_chars_per_subtoken) {
64*993b0882SAndroid Build Coastguard Worker       // If the max bytes of a subtoken is known, do not search beyond that
65*993b0882SAndroid Build Coastguard Worker       // length.
66*993b0882SAndroid Build Coastguard Worker       break;
67*993b0882SAndroid Build Coastguard Worker     }
68*993b0882SAndroid Build Coastguard Worker   }
69*993b0882SAndroid Build Coastguard Worker   int n = byte_ends.size();
70*993b0882SAndroid Build Coastguard Worker   for (int i = n - 1; i >= 0; i--) {
71*993b0882SAndroid Build Coastguard Worker     bool in_vocab;
72*993b0882SAndroid Build Coastguard Worker     auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator,
73*993b0882SAndroid Build Coastguard Worker                          vocab_map, &in_vocab);
74*993b0882SAndroid Build Coastguard Worker     if (!status.success) return status;
75*993b0882SAndroid Build Coastguard Worker     if (in_vocab) {
76*993b0882SAndroid Build Coastguard Worker       *byte_end = byte_ends[i];
77*993b0882SAndroid Build Coastguard Worker       *found_match = true;
78*993b0882SAndroid Build Coastguard Worker       return LookupStatus::OK();
79*993b0882SAndroid Build Coastguard Worker     }
80*993b0882SAndroid Build Coastguard Worker     if (i == 0 && split_unknown_characters) {
81*993b0882SAndroid Build Coastguard Worker       *byte_end = byte_ends[0];
82*993b0882SAndroid Build Coastguard Worker       *found_match = true;
83*993b0882SAndroid Build Coastguard Worker       *match_is_unknown_character = true;
84*993b0882SAndroid Build Coastguard Worker       return LookupStatus::OK();
85*993b0882SAndroid Build Coastguard Worker     }
86*993b0882SAndroid Build Coastguard Worker   }
87*993b0882SAndroid Build Coastguard Worker   return LookupStatus::OK();
88*993b0882SAndroid Build Coastguard Worker }
89*993b0882SAndroid Build Coastguard Worker 
90*993b0882SAndroid Build Coastguard Worker // Sets the outputs 'begin_offset', 'end_offset' and 'num_word_pieces' when no
91*993b0882SAndroid Build Coastguard Worker // token is found.
NoTokenFound(const absl::string_view token,bool use_unknown_token,const std::string & unknown_token,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)92*993b0882SAndroid Build Coastguard Worker LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
93*993b0882SAndroid Build Coastguard Worker                           const std::string& unknown_token,
94*993b0882SAndroid Build Coastguard Worker                           std::vector<std::string>* subwords,
95*993b0882SAndroid Build Coastguard Worker                           std::vector<int>* begin_offset,
96*993b0882SAndroid Build Coastguard Worker                           std::vector<int>* end_offset, int* num_word_pieces) {
97*993b0882SAndroid Build Coastguard Worker   begin_offset->push_back(0);
98*993b0882SAndroid Build Coastguard Worker   if (use_unknown_token) {
99*993b0882SAndroid Build Coastguard Worker     subwords->push_back(unknown_token);
100*993b0882SAndroid Build Coastguard Worker     end_offset->push_back(token.length());
101*993b0882SAndroid Build Coastguard Worker   } else {
102*993b0882SAndroid Build Coastguard Worker     subwords->emplace_back(token.data(), token.length());
103*993b0882SAndroid Build Coastguard Worker     end_offset->push_back(token.length());
104*993b0882SAndroid Build Coastguard Worker   }
105*993b0882SAndroid Build Coastguard Worker   ++(*num_word_pieces);
106*993b0882SAndroid Build Coastguard Worker 
107*993b0882SAndroid Build Coastguard Worker   return LookupStatus::OK();
108*993b0882SAndroid Build Coastguard Worker }
109*993b0882SAndroid Build Coastguard Worker 
110*993b0882SAndroid Build Coastguard Worker // When a subword is found, this helper function will add the outputs to
111*993b0882SAndroid Build Coastguard Worker // 'subwords', 'begin_offset' and 'end_offset'.
AddWord(const absl::string_view token,int byte_start,int byte_end,const std::string & suffix_indicator,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset)112*993b0882SAndroid Build Coastguard Worker void AddWord(const absl::string_view token, int byte_start, int byte_end,
113*993b0882SAndroid Build Coastguard Worker              const std::string& suffix_indicator,
114*993b0882SAndroid Build Coastguard Worker              std::vector<std::string>* subwords, std::vector<int>* begin_offset,
115*993b0882SAndroid Build Coastguard Worker              std::vector<int>* end_offset) {
116*993b0882SAndroid Build Coastguard Worker   begin_offset->push_back(byte_start);
117*993b0882SAndroid Build Coastguard Worker   int len = byte_end - byte_start;
118*993b0882SAndroid Build Coastguard Worker 
119*993b0882SAndroid Build Coastguard Worker   if (byte_start > 0) {
120*993b0882SAndroid Build Coastguard Worker     // Prepend suffix_indicator if the token is within a word.
121*993b0882SAndroid Build Coastguard Worker     subwords->push_back(::absl::StrCat(
122*993b0882SAndroid Build Coastguard Worker         suffix_indicator, absl::string_view(token.data() + byte_start, len)));
123*993b0882SAndroid Build Coastguard Worker   } else {
124*993b0882SAndroid Build Coastguard Worker     subwords->emplace_back(token.data(), len);
125*993b0882SAndroid Build Coastguard Worker   }
126*993b0882SAndroid Build Coastguard Worker   end_offset->push_back(byte_end);
127*993b0882SAndroid Build Coastguard Worker }
128*993b0882SAndroid Build Coastguard Worker 
129*993b0882SAndroid Build Coastguard Worker // Adds a single unknown character subword, found when split_unknown_characters
130*993b0882SAndroid Build Coastguard Worker // is true.
AddUnknownCharacter(const absl::string_view token,int byte_start,int byte_end,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset)131*993b0882SAndroid Build Coastguard Worker void AddUnknownCharacter(const absl::string_view token, int byte_start,
132*993b0882SAndroid Build Coastguard Worker                          int byte_end, const std::string& suffix_indicator,
133*993b0882SAndroid Build Coastguard Worker                          bool use_unknown_token,
134*993b0882SAndroid Build Coastguard Worker                          const std::string& unknown_token,
135*993b0882SAndroid Build Coastguard Worker                          std::vector<std::string>* subwords,
136*993b0882SAndroid Build Coastguard Worker                          std::vector<int>* begin_offset,
137*993b0882SAndroid Build Coastguard Worker                          std::vector<int>* end_offset) {
138*993b0882SAndroid Build Coastguard Worker   begin_offset->push_back(byte_start);
139*993b0882SAndroid Build Coastguard Worker   end_offset->push_back(byte_end);
140*993b0882SAndroid Build Coastguard Worker   int len = byte_end - byte_start;
141*993b0882SAndroid Build Coastguard Worker   if (use_unknown_token) {
142*993b0882SAndroid Build Coastguard Worker     if (byte_start > 0) {
143*993b0882SAndroid Build Coastguard Worker       // Prepend suffix_indicator if the character is within a word.
144*993b0882SAndroid Build Coastguard Worker       subwords->push_back(::absl::StrCat(suffix_indicator, unknown_token));
145*993b0882SAndroid Build Coastguard Worker     } else {
146*993b0882SAndroid Build Coastguard Worker       subwords->push_back(unknown_token);
147*993b0882SAndroid Build Coastguard Worker     }
148*993b0882SAndroid Build Coastguard Worker   } else {
149*993b0882SAndroid Build Coastguard Worker     if (byte_start > 0) {
150*993b0882SAndroid Build Coastguard Worker       // Prepend suffix_indicator if the character is within a word.
151*993b0882SAndroid Build Coastguard Worker       subwords->push_back(::absl::StrCat(
152*993b0882SAndroid Build Coastguard Worker           suffix_indicator, absl::string_view(token.data() + byte_start, len)));
153*993b0882SAndroid Build Coastguard Worker     } else {
154*993b0882SAndroid Build Coastguard Worker       subwords->emplace_back(token.data(), len);
155*993b0882SAndroid Build Coastguard Worker     }
156*993b0882SAndroid Build Coastguard Worker   }
157*993b0882SAndroid Build Coastguard Worker }
158*993b0882SAndroid Build Coastguard Worker 
TokenizeL2RGreedy(const absl::string_view token,const int max_bytes_per_token,const int max_chars_per_subtoken,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,bool split_unknown_characters,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)159*993b0882SAndroid Build Coastguard Worker LookupStatus TokenizeL2RGreedy(
160*993b0882SAndroid Build Coastguard Worker     const absl::string_view token, const int max_bytes_per_token,
161*993b0882SAndroid Build Coastguard Worker     const int max_chars_per_subtoken, const std::string& suffix_indicator,
162*993b0882SAndroid Build Coastguard Worker     bool use_unknown_token, const std::string& unknown_token,
163*993b0882SAndroid Build Coastguard Worker     bool split_unknown_characters, const WordpieceVocab* vocab_map,
164*993b0882SAndroid Build Coastguard Worker     std::vector<std::string>* subwords, std::vector<int>* begin_offset,
165*993b0882SAndroid Build Coastguard Worker     std::vector<int>* end_offset, int* num_word_pieces) {
166*993b0882SAndroid Build Coastguard Worker   std::vector<std::string> candidate_subwords;
167*993b0882SAndroid Build Coastguard Worker   std::vector<int> candidate_begin_offsets;
168*993b0882SAndroid Build Coastguard Worker   std::vector<int> candidate_end_offsets;
169*993b0882SAndroid Build Coastguard Worker   const int token_len = token.length();
170*993b0882SAndroid Build Coastguard Worker   for (int byte_start = 0; byte_start < token_len;) {
171*993b0882SAndroid Build Coastguard Worker     int byte_end;
172*993b0882SAndroid Build Coastguard Worker     bool found_subword;
173*993b0882SAndroid Build Coastguard Worker     bool match_is_unknown_character;
174*993b0882SAndroid Build Coastguard Worker     auto status = LongestMatchStartingAt(
175*993b0882SAndroid Build Coastguard Worker         byte_start, token, suffix_indicator, max_chars_per_subtoken,
176*993b0882SAndroid Build Coastguard Worker         split_unknown_characters, vocab_map, &byte_end, &found_subword,
177*993b0882SAndroid Build Coastguard Worker         &match_is_unknown_character);
178*993b0882SAndroid Build Coastguard Worker     if (!status.success) return status;
179*993b0882SAndroid Build Coastguard Worker     if (found_subword) {
180*993b0882SAndroid Build Coastguard Worker       if (match_is_unknown_character) {
181*993b0882SAndroid Build Coastguard Worker         AddUnknownCharacter(token, byte_start, byte_end, suffix_indicator,
182*993b0882SAndroid Build Coastguard Worker                             use_unknown_token, unknown_token,
183*993b0882SAndroid Build Coastguard Worker                             &candidate_subwords, &candidate_begin_offsets,
184*993b0882SAndroid Build Coastguard Worker                             &candidate_end_offsets);
185*993b0882SAndroid Build Coastguard Worker       } else {
186*993b0882SAndroid Build Coastguard Worker         AddWord(token, byte_start, byte_end, suffix_indicator,
187*993b0882SAndroid Build Coastguard Worker                 &candidate_subwords, &candidate_begin_offsets,
188*993b0882SAndroid Build Coastguard Worker                 &candidate_end_offsets);
189*993b0882SAndroid Build Coastguard Worker       }
190*993b0882SAndroid Build Coastguard Worker       byte_start = byte_end;
191*993b0882SAndroid Build Coastguard Worker     } else {
192*993b0882SAndroid Build Coastguard Worker       return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
193*993b0882SAndroid Build Coastguard Worker                           begin_offset, end_offset, num_word_pieces);
194*993b0882SAndroid Build Coastguard Worker     }
195*993b0882SAndroid Build Coastguard Worker   }
196*993b0882SAndroid Build Coastguard Worker 
197*993b0882SAndroid Build Coastguard Worker   subwords->insert(subwords->end(), candidate_subwords.begin(),
198*993b0882SAndroid Build Coastguard Worker                    candidate_subwords.end());
199*993b0882SAndroid Build Coastguard Worker   begin_offset->insert(begin_offset->end(), candidate_begin_offsets.begin(),
200*993b0882SAndroid Build Coastguard Worker                        candidate_begin_offsets.end());
201*993b0882SAndroid Build Coastguard Worker   end_offset->insert(end_offset->end(), candidate_end_offsets.begin(),
202*993b0882SAndroid Build Coastguard Worker                      candidate_end_offsets.end());
203*993b0882SAndroid Build Coastguard Worker   *num_word_pieces += candidate_subwords.size();
204*993b0882SAndroid Build Coastguard Worker   return LookupStatus::OK();
205*993b0882SAndroid Build Coastguard Worker }
206*993b0882SAndroid Build Coastguard Worker 
207*993b0882SAndroid Build Coastguard Worker }  // namespace
208*993b0882SAndroid Build Coastguard Worker 
WordpieceTokenize(const absl::string_view token,const int max_bytes_per_token,const int max_chars_per_subtoken,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,bool split_unknown_characters,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)209*993b0882SAndroid Build Coastguard Worker LookupStatus WordpieceTokenize(
210*993b0882SAndroid Build Coastguard Worker     const absl::string_view token, const int max_bytes_per_token,
211*993b0882SAndroid Build Coastguard Worker     const int max_chars_per_subtoken, const std::string& suffix_indicator,
212*993b0882SAndroid Build Coastguard Worker     bool use_unknown_token, const std::string& unknown_token,
213*993b0882SAndroid Build Coastguard Worker     bool split_unknown_characters, const WordpieceVocab* vocab_map,
214*993b0882SAndroid Build Coastguard Worker     std::vector<std::string>* subwords, std::vector<int>* begin_offset,
215*993b0882SAndroid Build Coastguard Worker     std::vector<int>* end_offset, int* num_word_pieces) {
216*993b0882SAndroid Build Coastguard Worker   int token_len = token.size();
217*993b0882SAndroid Build Coastguard Worker   if (token_len > max_bytes_per_token) {
218*993b0882SAndroid Build Coastguard Worker     begin_offset->push_back(0);
219*993b0882SAndroid Build Coastguard Worker     *num_word_pieces = 1;
220*993b0882SAndroid Build Coastguard Worker     if (use_unknown_token) {
221*993b0882SAndroid Build Coastguard Worker       subwords->emplace_back(unknown_token);
222*993b0882SAndroid Build Coastguard Worker     } else {
223*993b0882SAndroid Build Coastguard Worker       subwords->emplace_back(token);
224*993b0882SAndroid Build Coastguard Worker     }
225*993b0882SAndroid Build Coastguard Worker     end_offset->push_back(token.size());
226*993b0882SAndroid Build Coastguard Worker     return LookupStatus::OK();
227*993b0882SAndroid Build Coastguard Worker   }
228*993b0882SAndroid Build Coastguard Worker   return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken,
229*993b0882SAndroid Build Coastguard Worker                            suffix_indicator, use_unknown_token, unknown_token,
230*993b0882SAndroid Build Coastguard Worker                            split_unknown_characters, vocab_map, subwords,
231*993b0882SAndroid Build Coastguard Worker                            begin_offset, end_offset, num_word_pieces);
232*993b0882SAndroid Build Coastguard Worker }
233*993b0882SAndroid Build Coastguard Worker 
WordpieceTokenize(const absl::string_view token,const int max_bytes_per_token,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)234*993b0882SAndroid Build Coastguard Worker LookupStatus WordpieceTokenize(
235*993b0882SAndroid Build Coastguard Worker     const absl::string_view token, const int max_bytes_per_token,
236*993b0882SAndroid Build Coastguard Worker     const std::string& suffix_indicator, bool use_unknown_token,
237*993b0882SAndroid Build Coastguard Worker     const std::string& unknown_token, const WordpieceVocab* vocab_map,
238*993b0882SAndroid Build Coastguard Worker     std::vector<std::string>* subwords, std::vector<int>* begin_offset,
239*993b0882SAndroid Build Coastguard Worker     std::vector<int>* end_offset, int* num_word_pieces) {
240*993b0882SAndroid Build Coastguard Worker   return WordpieceTokenize(token, max_bytes_per_token,
241*993b0882SAndroid Build Coastguard Worker                            /* max_chars_per_subtoken= */ 0, suffix_indicator,
242*993b0882SAndroid Build Coastguard Worker                            use_unknown_token, unknown_token,
243*993b0882SAndroid Build Coastguard Worker                            /* split_unknown_characters= */ false, vocab_map,
244*993b0882SAndroid Build Coastguard Worker                            subwords, begin_offset, end_offset, num_word_pieces);
245*993b0882SAndroid Build Coastguard Worker }
246*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
247