xref: /aosp_15_r20/external/libtextclassifier/native/utils/bert_tokenizer.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
19*993b0882SAndroid Build Coastguard Worker 
20*993b0882SAndroid Build Coastguard Worker #include <fstream>
21*993b0882SAndroid Build Coastguard Worker #include <string>
22*993b0882SAndroid Build Coastguard Worker #include <vector>
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/wordpiece_tokenizer.h"
26*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h"
27*993b0882SAndroid Build Coastguard Worker #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
28*993b0882SAndroid Build Coastguard Worker #include "tensorflow_lite_support/cc/utils/common_utils.h"
29*993b0882SAndroid Build Coastguard Worker 
30*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
31*993b0882SAndroid Build Coastguard Worker 
32*993b0882SAndroid Build Coastguard Worker using ::tflite::support::text::tokenizer::TokenizerResult;
33*993b0882SAndroid Build Coastguard Worker using ::tflite::support::utils::LoadVocabFromBuffer;
34*993b0882SAndroid Build Coastguard Worker using ::tflite::support::utils::LoadVocabFromFile;
35*993b0882SAndroid Build Coastguard Worker 
// Default configuration values for BertTokenizerOptions (see below).
// Declared `inline` (C++17) so every translation unit that includes this
// header shares a single definition and address, instead of each TU getting
// its own internal-linkage copy of a plain namespace-scope `constexpr`.
inline constexpr int kDefaultMaxBytesPerToken = 100;
inline constexpr int kDefaultMaxCharsPerSubToken = 100;
inline constexpr char kDefaultSuffixIndicator[] = "##";
inline constexpr bool kDefaultUseUnknownToken = true;
inline constexpr char kDefaultUnknownToken[] = "[UNK]";
inline constexpr bool kDefaultSplitUnknownChars = false;
42*993b0882SAndroid Build Coastguard Worker 
// Result of wordpiece tokenization including subwords and offsets.
// Inherits `subwords` from TokenizerResult and adds per-subword offsets plus
// the subword count for each pretokenized input token.
// Example:
// input:                tokenize     me  please
// subwords:             token ##ize  me  plea ##se
// wp_begin_offset:     [0,      5,   9,  12,    16]
// wp_end_offset:       [     5,    8,  11,   16,  18]
// row_lengths:         [2,          1,  1]
struct WordpieceTokenizerResult
    : tflite::support::text::tokenizer::TokenizerResult {
  // Codepoint index in the original input where each subword begins.
  std::vector<int> wp_begin_offset;
  // Codepoint index in the original input just past the end of each subword
  // (exclusive), so subword i spans [wp_begin_offset[i], wp_end_offset[i]).
  std::vector<int> wp_end_offset;
  // Number of subwords produced for each input token; entries sum to the
  // total subword count.
  std::vector<int> row_lengths;
};
56*993b0882SAndroid Build Coastguard Worker 
// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  // Maximum length in bytes of an input token; see kDefaultMaxBytesPerToken.
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  // Maximum length in characters of a single subtoken (wordpiece).
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  // Prefix marking a subword that continues a previous one, e.g. "##ize".
  std::string suffix_indicator = kDefaultSuffixIndicator;
  // Whether to map out-of-vocabulary tokens to `unknown_token`.
  bool use_unknown_token = kDefaultUseUnknownToken;
  // Token emitted for out-of-vocabulary input when use_unknown_token is set.
  std::string unknown_token = kDefaultUnknownToken;
  // Whether unknown tokens are split into individual characters.
  bool split_unknown_chars = kDefaultSplitUnknownChars;
};
// A flat-hash-map based implementation of WordpieceVocab, used in
// BertTokenizer to invoke tensorflow::text::WordpieceTokenize within.
class FlatHashMapBackedWordpiece : public WordpieceVocab {
 public:
  // Builds the vocabulary from `vocab`, one entry per word; a word's index in
  // the vector is its id.
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  // WordpieceVocab interface: sets *value to whether `key` is in the vocab.
  LookupStatus Contains(absl::string_view key, bool* value) const override;
  // Looks up the id of `key`; returns false if the word is not in the vocab.
  bool LookupId(absl::string_view key, int* result) const;
  // Reverse lookup: retrieves the word for `vocab_id`; returns false if the
  // id is out of range.
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  // Number of words in the vocabulary.
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words indexed position in vocabulary file.
  std::vector<std::string> vocab_;
  // Word -> id index. NOTE(review): the string_view keys presumably point
  // into the strings held by vocab_ (confirm in the .cc); if so, the
  // implicitly-generated copy operations would leave a copy's index_map_
  // viewing the source object's storage.
  absl::flat_hash_map<absl::string_view, int> index_map_;
};
// Wordpiece tokenizer for bert models. Initialized with a vocab file or vector.
//
// The full tokenization involves two steps: Splitting the input into tokens
// (pretokenization) and splitting the tokens into subwords.
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initialize the tokenizer from vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}

  // Initialize the tokenizer from file path to vocab and tokenizer configs.
  // Delegates to the vector constructor after loading the file, one vocab
  // entry per line.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}

  // Initialize the tokenizer from buffer and size of vocab and tokenizer
  // configs. Delegates to the vector constructor after parsing the buffer.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
                      options) {}

  // Perform tokenization, first tokenize the input and then find the subwords.
  // Return tokenized results containing the subwords.
  TokenizerResult Tokenize(const std::string& input) override;

  // Perform tokenization, first tokenize the input and then find the subwords.
  // Return tokenized results containing the subwords and codepoint indices.
  WordpieceTokenizerResult TokenizeIntoWordpieces(const std::string& input);

  // Perform tokenization on a single token (no pretokenization), return
  // tokenized results containing the subwords and codepoint indices.
  WordpieceTokenizerResult TokenizeSingleToken(const std::string& token);

  // Perform tokenization on already pretokenized input, return tokenized
  // results containing the subwords and codepoint indices.
  WordpieceTokenizerResult TokenizeIntoWordpieces(
      const std::vector<Token>& tokens);

  // Check if a certain key is included in the vocab. Forwards to the
  // underlying vocabulary.
  LookupStatus Contains(const absl::string_view key, bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Find the id of a wordpiece. Returns false if `key` is not in the vocab.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Find the wordpiece from an id. Returns false if `vocab_id` is invalid.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  // Number of words in the vocabulary.
  int VocabularySize() const { return vocab_.VocabularySize(); }

  // Pretokenization step: splits `input` into tokens before wordpiece lookup.
  // NOTE(review): the exact splitting rules live in the .cc — presumably
  // whitespace/punctuation based; confirm there.
  static std::vector<std::string> PreTokenize(const absl::string_view input);

 private:
  // Vocabulary used for subword lookup.
  FlatHashMapBackedWordpiece vocab_;
  // Tokenizer configuration captured at construction time.
  BertTokenizerOptions options_;
};
147*993b0882SAndroid Build Coastguard Worker 
148*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
149*993b0882SAndroid Build Coastguard Worker 
150*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
151