1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
#include "utils/wordpiece_tokenizer.h"

#include <iterator>

#include "utils/utf8/unicodetext.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
23*993b0882SAndroid Build Coastguard Worker
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker
26*993b0882SAndroid Build Coastguard Worker namespace {
27*993b0882SAndroid Build Coastguard Worker
Lookup(int byte_start,int byte_end,const absl::string_view token,const std::string & suffix_indicator,const WordpieceVocab * vocab_map,bool * in_vocab)28*993b0882SAndroid Build Coastguard Worker LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
29*993b0882SAndroid Build Coastguard Worker const std::string& suffix_indicator,
30*993b0882SAndroid Build Coastguard Worker const WordpieceVocab* vocab_map, bool* in_vocab) {
31*993b0882SAndroid Build Coastguard Worker int byte_len = byte_end - byte_start;
32*993b0882SAndroid Build Coastguard Worker absl::string_view substr(token.data() + byte_start, byte_len);
33*993b0882SAndroid Build Coastguard Worker std::string lookup_value;
34*993b0882SAndroid Build Coastguard Worker if (byte_start > 0) {
35*993b0882SAndroid Build Coastguard Worker lookup_value = absl::StrCat(suffix_indicator, substr);
36*993b0882SAndroid Build Coastguard Worker } else {
37*993b0882SAndroid Build Coastguard Worker // absl::CopyToString
38*993b0882SAndroid Build Coastguard Worker lookup_value.assign(substr.begin(), substr.end());
39*993b0882SAndroid Build Coastguard Worker }
40*993b0882SAndroid Build Coastguard Worker return vocab_map->Contains(lookup_value, in_vocab);
41*993b0882SAndroid Build Coastguard Worker }
42*993b0882SAndroid Build Coastguard Worker
43*993b0882SAndroid Build Coastguard Worker // Sets byte_end to the longest byte sequence which:
44*993b0882SAndroid Build Coastguard Worker // 1) is a proper UTF8 sequence
45*993b0882SAndroid Build Coastguard Worker // 2) is in the vocab OR if split_unknown_characters is true, is a single
46*993b0882SAndroid Build Coastguard Worker // UTF8 character.
47*993b0882SAndroid Build Coastguard Worker // If no match is found, found_match is set to false.
LongestMatchStartingAt(int byte_start,const absl::string_view token,const std::string & suffix_indicator,const int max_chars_per_subtoken,bool split_unknown_characters,const WordpieceVocab * vocab_map,int * byte_end,bool * found_match,bool * match_is_unknown_character)48*993b0882SAndroid Build Coastguard Worker LookupStatus LongestMatchStartingAt(
49*993b0882SAndroid Build Coastguard Worker int byte_start, const absl::string_view token,
50*993b0882SAndroid Build Coastguard Worker const std::string& suffix_indicator, const int max_chars_per_subtoken,
51*993b0882SAndroid Build Coastguard Worker bool split_unknown_characters, const WordpieceVocab* vocab_map,
52*993b0882SAndroid Build Coastguard Worker int* byte_end, bool* found_match, bool* match_is_unknown_character) {
53*993b0882SAndroid Build Coastguard Worker *match_is_unknown_character = false;
54*993b0882SAndroid Build Coastguard Worker *found_match = false;
55*993b0882SAndroid Build Coastguard Worker const UnicodeText unicode_token =
56*993b0882SAndroid Build Coastguard Worker UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);
57*993b0882SAndroid Build Coastguard Worker std::vector<int32_t> byte_ends;
58*993b0882SAndroid Build Coastguard Worker int32_t codepoint_offset = byte_start;
59*993b0882SAndroid Build Coastguard Worker for (auto it = unicode_token.begin(); it != unicode_token.end(); ++it) {
60*993b0882SAndroid Build Coastguard Worker codepoint_offset += it.utf8_length();
61*993b0882SAndroid Build Coastguard Worker byte_ends.push_back(codepoint_offset);
62*993b0882SAndroid Build Coastguard Worker if (max_chars_per_subtoken > 0 &&
63*993b0882SAndroid Build Coastguard Worker byte_ends.size() == max_chars_per_subtoken) {
64*993b0882SAndroid Build Coastguard Worker // If the max bytes of a subtoken is known, do not search beyond that
65*993b0882SAndroid Build Coastguard Worker // length.
66*993b0882SAndroid Build Coastguard Worker break;
67*993b0882SAndroid Build Coastguard Worker }
68*993b0882SAndroid Build Coastguard Worker }
69*993b0882SAndroid Build Coastguard Worker int n = byte_ends.size();
70*993b0882SAndroid Build Coastguard Worker for (int i = n - 1; i >= 0; i--) {
71*993b0882SAndroid Build Coastguard Worker bool in_vocab;
72*993b0882SAndroid Build Coastguard Worker auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator,
73*993b0882SAndroid Build Coastguard Worker vocab_map, &in_vocab);
74*993b0882SAndroid Build Coastguard Worker if (!status.success) return status;
75*993b0882SAndroid Build Coastguard Worker if (in_vocab) {
76*993b0882SAndroid Build Coastguard Worker *byte_end = byte_ends[i];
77*993b0882SAndroid Build Coastguard Worker *found_match = true;
78*993b0882SAndroid Build Coastguard Worker return LookupStatus::OK();
79*993b0882SAndroid Build Coastguard Worker }
80*993b0882SAndroid Build Coastguard Worker if (i == 0 && split_unknown_characters) {
81*993b0882SAndroid Build Coastguard Worker *byte_end = byte_ends[0];
82*993b0882SAndroid Build Coastguard Worker *found_match = true;
83*993b0882SAndroid Build Coastguard Worker *match_is_unknown_character = true;
84*993b0882SAndroid Build Coastguard Worker return LookupStatus::OK();
85*993b0882SAndroid Build Coastguard Worker }
86*993b0882SAndroid Build Coastguard Worker }
87*993b0882SAndroid Build Coastguard Worker return LookupStatus::OK();
88*993b0882SAndroid Build Coastguard Worker }
89*993b0882SAndroid Build Coastguard Worker
90*993b0882SAndroid Build Coastguard Worker // Sets the outputs 'begin_offset', 'end_offset' and 'num_word_pieces' when no
91*993b0882SAndroid Build Coastguard Worker // token is found.
NoTokenFound(const absl::string_view token,bool use_unknown_token,const std::string & unknown_token,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)92*993b0882SAndroid Build Coastguard Worker LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
93*993b0882SAndroid Build Coastguard Worker const std::string& unknown_token,
94*993b0882SAndroid Build Coastguard Worker std::vector<std::string>* subwords,
95*993b0882SAndroid Build Coastguard Worker std::vector<int>* begin_offset,
96*993b0882SAndroid Build Coastguard Worker std::vector<int>* end_offset, int* num_word_pieces) {
97*993b0882SAndroid Build Coastguard Worker begin_offset->push_back(0);
98*993b0882SAndroid Build Coastguard Worker if (use_unknown_token) {
99*993b0882SAndroid Build Coastguard Worker subwords->push_back(unknown_token);
100*993b0882SAndroid Build Coastguard Worker end_offset->push_back(token.length());
101*993b0882SAndroid Build Coastguard Worker } else {
102*993b0882SAndroid Build Coastguard Worker subwords->emplace_back(token.data(), token.length());
103*993b0882SAndroid Build Coastguard Worker end_offset->push_back(token.length());
104*993b0882SAndroid Build Coastguard Worker }
105*993b0882SAndroid Build Coastguard Worker ++(*num_word_pieces);
106*993b0882SAndroid Build Coastguard Worker
107*993b0882SAndroid Build Coastguard Worker return LookupStatus::OK();
108*993b0882SAndroid Build Coastguard Worker }
109*993b0882SAndroid Build Coastguard Worker
110*993b0882SAndroid Build Coastguard Worker // When a subword is found, this helper function will add the outputs to
111*993b0882SAndroid Build Coastguard Worker // 'subwords', 'begin_offset' and 'end_offset'.
AddWord(const absl::string_view token,int byte_start,int byte_end,const std::string & suffix_indicator,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset)112*993b0882SAndroid Build Coastguard Worker void AddWord(const absl::string_view token, int byte_start, int byte_end,
113*993b0882SAndroid Build Coastguard Worker const std::string& suffix_indicator,
114*993b0882SAndroid Build Coastguard Worker std::vector<std::string>* subwords, std::vector<int>* begin_offset,
115*993b0882SAndroid Build Coastguard Worker std::vector<int>* end_offset) {
116*993b0882SAndroid Build Coastguard Worker begin_offset->push_back(byte_start);
117*993b0882SAndroid Build Coastguard Worker int len = byte_end - byte_start;
118*993b0882SAndroid Build Coastguard Worker
119*993b0882SAndroid Build Coastguard Worker if (byte_start > 0) {
120*993b0882SAndroid Build Coastguard Worker // Prepend suffix_indicator if the token is within a word.
121*993b0882SAndroid Build Coastguard Worker subwords->push_back(::absl::StrCat(
122*993b0882SAndroid Build Coastguard Worker suffix_indicator, absl::string_view(token.data() + byte_start, len)));
123*993b0882SAndroid Build Coastguard Worker } else {
124*993b0882SAndroid Build Coastguard Worker subwords->emplace_back(token.data(), len);
125*993b0882SAndroid Build Coastguard Worker }
126*993b0882SAndroid Build Coastguard Worker end_offset->push_back(byte_end);
127*993b0882SAndroid Build Coastguard Worker }
128*993b0882SAndroid Build Coastguard Worker
129*993b0882SAndroid Build Coastguard Worker // Adds a single unknown character subword, found when split_unknown_characters
130*993b0882SAndroid Build Coastguard Worker // is true.
AddUnknownCharacter(const absl::string_view token,int byte_start,int byte_end,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset)131*993b0882SAndroid Build Coastguard Worker void AddUnknownCharacter(const absl::string_view token, int byte_start,
132*993b0882SAndroid Build Coastguard Worker int byte_end, const std::string& suffix_indicator,
133*993b0882SAndroid Build Coastguard Worker bool use_unknown_token,
134*993b0882SAndroid Build Coastguard Worker const std::string& unknown_token,
135*993b0882SAndroid Build Coastguard Worker std::vector<std::string>* subwords,
136*993b0882SAndroid Build Coastguard Worker std::vector<int>* begin_offset,
137*993b0882SAndroid Build Coastguard Worker std::vector<int>* end_offset) {
138*993b0882SAndroid Build Coastguard Worker begin_offset->push_back(byte_start);
139*993b0882SAndroid Build Coastguard Worker end_offset->push_back(byte_end);
140*993b0882SAndroid Build Coastguard Worker int len = byte_end - byte_start;
141*993b0882SAndroid Build Coastguard Worker if (use_unknown_token) {
142*993b0882SAndroid Build Coastguard Worker if (byte_start > 0) {
143*993b0882SAndroid Build Coastguard Worker // Prepend suffix_indicator if the character is within a word.
144*993b0882SAndroid Build Coastguard Worker subwords->push_back(::absl::StrCat(suffix_indicator, unknown_token));
145*993b0882SAndroid Build Coastguard Worker } else {
146*993b0882SAndroid Build Coastguard Worker subwords->push_back(unknown_token);
147*993b0882SAndroid Build Coastguard Worker }
148*993b0882SAndroid Build Coastguard Worker } else {
149*993b0882SAndroid Build Coastguard Worker if (byte_start > 0) {
150*993b0882SAndroid Build Coastguard Worker // Prepend suffix_indicator if the character is within a word.
151*993b0882SAndroid Build Coastguard Worker subwords->push_back(::absl::StrCat(
152*993b0882SAndroid Build Coastguard Worker suffix_indicator, absl::string_view(token.data() + byte_start, len)));
153*993b0882SAndroid Build Coastguard Worker } else {
154*993b0882SAndroid Build Coastguard Worker subwords->emplace_back(token.data(), len);
155*993b0882SAndroid Build Coastguard Worker }
156*993b0882SAndroid Build Coastguard Worker }
157*993b0882SAndroid Build Coastguard Worker }
158*993b0882SAndroid Build Coastguard Worker
TokenizeL2RGreedy(const absl::string_view token,const int max_bytes_per_token,const int max_chars_per_subtoken,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,bool split_unknown_characters,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)159*993b0882SAndroid Build Coastguard Worker LookupStatus TokenizeL2RGreedy(
160*993b0882SAndroid Build Coastguard Worker const absl::string_view token, const int max_bytes_per_token,
161*993b0882SAndroid Build Coastguard Worker const int max_chars_per_subtoken, const std::string& suffix_indicator,
162*993b0882SAndroid Build Coastguard Worker bool use_unknown_token, const std::string& unknown_token,
163*993b0882SAndroid Build Coastguard Worker bool split_unknown_characters, const WordpieceVocab* vocab_map,
164*993b0882SAndroid Build Coastguard Worker std::vector<std::string>* subwords, std::vector<int>* begin_offset,
165*993b0882SAndroid Build Coastguard Worker std::vector<int>* end_offset, int* num_word_pieces) {
166*993b0882SAndroid Build Coastguard Worker std::vector<std::string> candidate_subwords;
167*993b0882SAndroid Build Coastguard Worker std::vector<int> candidate_begin_offsets;
168*993b0882SAndroid Build Coastguard Worker std::vector<int> candidate_end_offsets;
169*993b0882SAndroid Build Coastguard Worker const int token_len = token.length();
170*993b0882SAndroid Build Coastguard Worker for (int byte_start = 0; byte_start < token_len;) {
171*993b0882SAndroid Build Coastguard Worker int byte_end;
172*993b0882SAndroid Build Coastguard Worker bool found_subword;
173*993b0882SAndroid Build Coastguard Worker bool match_is_unknown_character;
174*993b0882SAndroid Build Coastguard Worker auto status = LongestMatchStartingAt(
175*993b0882SAndroid Build Coastguard Worker byte_start, token, suffix_indicator, max_chars_per_subtoken,
176*993b0882SAndroid Build Coastguard Worker split_unknown_characters, vocab_map, &byte_end, &found_subword,
177*993b0882SAndroid Build Coastguard Worker &match_is_unknown_character);
178*993b0882SAndroid Build Coastguard Worker if (!status.success) return status;
179*993b0882SAndroid Build Coastguard Worker if (found_subword) {
180*993b0882SAndroid Build Coastguard Worker if (match_is_unknown_character) {
181*993b0882SAndroid Build Coastguard Worker AddUnknownCharacter(token, byte_start, byte_end, suffix_indicator,
182*993b0882SAndroid Build Coastguard Worker use_unknown_token, unknown_token,
183*993b0882SAndroid Build Coastguard Worker &candidate_subwords, &candidate_begin_offsets,
184*993b0882SAndroid Build Coastguard Worker &candidate_end_offsets);
185*993b0882SAndroid Build Coastguard Worker } else {
186*993b0882SAndroid Build Coastguard Worker AddWord(token, byte_start, byte_end, suffix_indicator,
187*993b0882SAndroid Build Coastguard Worker &candidate_subwords, &candidate_begin_offsets,
188*993b0882SAndroid Build Coastguard Worker &candidate_end_offsets);
189*993b0882SAndroid Build Coastguard Worker }
190*993b0882SAndroid Build Coastguard Worker byte_start = byte_end;
191*993b0882SAndroid Build Coastguard Worker } else {
192*993b0882SAndroid Build Coastguard Worker return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
193*993b0882SAndroid Build Coastguard Worker begin_offset, end_offset, num_word_pieces);
194*993b0882SAndroid Build Coastguard Worker }
195*993b0882SAndroid Build Coastguard Worker }
196*993b0882SAndroid Build Coastguard Worker
197*993b0882SAndroid Build Coastguard Worker subwords->insert(subwords->end(), candidate_subwords.begin(),
198*993b0882SAndroid Build Coastguard Worker candidate_subwords.end());
199*993b0882SAndroid Build Coastguard Worker begin_offset->insert(begin_offset->end(), candidate_begin_offsets.begin(),
200*993b0882SAndroid Build Coastguard Worker candidate_begin_offsets.end());
201*993b0882SAndroid Build Coastguard Worker end_offset->insert(end_offset->end(), candidate_end_offsets.begin(),
202*993b0882SAndroid Build Coastguard Worker candidate_end_offsets.end());
203*993b0882SAndroid Build Coastguard Worker *num_word_pieces += candidate_subwords.size();
204*993b0882SAndroid Build Coastguard Worker return LookupStatus::OK();
205*993b0882SAndroid Build Coastguard Worker }
206*993b0882SAndroid Build Coastguard Worker
207*993b0882SAndroid Build Coastguard Worker } // namespace
208*993b0882SAndroid Build Coastguard Worker
WordpieceTokenize(const absl::string_view token,const int max_bytes_per_token,const int max_chars_per_subtoken,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,bool split_unknown_characters,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)209*993b0882SAndroid Build Coastguard Worker LookupStatus WordpieceTokenize(
210*993b0882SAndroid Build Coastguard Worker const absl::string_view token, const int max_bytes_per_token,
211*993b0882SAndroid Build Coastguard Worker const int max_chars_per_subtoken, const std::string& suffix_indicator,
212*993b0882SAndroid Build Coastguard Worker bool use_unknown_token, const std::string& unknown_token,
213*993b0882SAndroid Build Coastguard Worker bool split_unknown_characters, const WordpieceVocab* vocab_map,
214*993b0882SAndroid Build Coastguard Worker std::vector<std::string>* subwords, std::vector<int>* begin_offset,
215*993b0882SAndroid Build Coastguard Worker std::vector<int>* end_offset, int* num_word_pieces) {
216*993b0882SAndroid Build Coastguard Worker int token_len = token.size();
217*993b0882SAndroid Build Coastguard Worker if (token_len > max_bytes_per_token) {
218*993b0882SAndroid Build Coastguard Worker begin_offset->push_back(0);
219*993b0882SAndroid Build Coastguard Worker *num_word_pieces = 1;
220*993b0882SAndroid Build Coastguard Worker if (use_unknown_token) {
221*993b0882SAndroid Build Coastguard Worker subwords->emplace_back(unknown_token);
222*993b0882SAndroid Build Coastguard Worker } else {
223*993b0882SAndroid Build Coastguard Worker subwords->emplace_back(token);
224*993b0882SAndroid Build Coastguard Worker }
225*993b0882SAndroid Build Coastguard Worker end_offset->push_back(token.size());
226*993b0882SAndroid Build Coastguard Worker return LookupStatus::OK();
227*993b0882SAndroid Build Coastguard Worker }
228*993b0882SAndroid Build Coastguard Worker return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken,
229*993b0882SAndroid Build Coastguard Worker suffix_indicator, use_unknown_token, unknown_token,
230*993b0882SAndroid Build Coastguard Worker split_unknown_characters, vocab_map, subwords,
231*993b0882SAndroid Build Coastguard Worker begin_offset, end_offset, num_word_pieces);
232*993b0882SAndroid Build Coastguard Worker }
233*993b0882SAndroid Build Coastguard Worker
WordpieceTokenize(const absl::string_view token,const int max_bytes_per_token,const std::string & suffix_indicator,bool use_unknown_token,const std::string & unknown_token,const WordpieceVocab * vocab_map,std::vector<std::string> * subwords,std::vector<int> * begin_offset,std::vector<int> * end_offset,int * num_word_pieces)234*993b0882SAndroid Build Coastguard Worker LookupStatus WordpieceTokenize(
235*993b0882SAndroid Build Coastguard Worker const absl::string_view token, const int max_bytes_per_token,
236*993b0882SAndroid Build Coastguard Worker const std::string& suffix_indicator, bool use_unknown_token,
237*993b0882SAndroid Build Coastguard Worker const std::string& unknown_token, const WordpieceVocab* vocab_map,
238*993b0882SAndroid Build Coastguard Worker std::vector<std::string>* subwords, std::vector<int>* begin_offset,
239*993b0882SAndroid Build Coastguard Worker std::vector<int>* end_offset, int* num_word_pieces) {
240*993b0882SAndroid Build Coastguard Worker return WordpieceTokenize(token, max_bytes_per_token,
241*993b0882SAndroid Build Coastguard Worker /* max_chars_per_subtoken= */ 0, suffix_indicator,
242*993b0882SAndroid Build Coastguard Worker use_unknown_token, unknown_token,
243*993b0882SAndroid Build Coastguard Worker /* split_unknown_characters= */ false, vocab_map,
244*993b0882SAndroid Build Coastguard Worker subwords, begin_offset, end_offset, num_word_pieces);
245*993b0882SAndroid Build Coastguard Worker }
246*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
247