xref: /aosp_15_r20/external/libtextclassifier/native/utils/bert_tokenizer.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
18 #define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
19 
#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

#include "annotator/types.h"
#include "utils/wordpiece_tokenizer.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"
29 
30 namespace libtextclassifier3 {
31 
32 using ::tflite::support::text::tokenizer::TokenizerResult;
33 using ::tflite::support::utils::LoadVocabFromBuffer;
34 using ::tflite::support::utils::LoadVocabFromFile;
35 
36 constexpr int kDefaultMaxBytesPerToken = 100;
37 constexpr int kDefaultMaxCharsPerSubToken = 100;
38 constexpr char kDefaultSuffixIndicator[] = "##";
39 constexpr bool kDefaultUseUnknownToken = true;
40 constexpr char kDefaultUnknownToken[] = "[UNK]";
41 constexpr bool kDefaultSplitUnknownChars = false;
42 
43 // Result of wordpiece tokenization including subwords and offsets.
44 // Example:
45 // input:                tokenize     me  please
46 // subwords:             token ##ize  me  plea ##se
47 // wp_begin_offset:     [0,      5,   9,  12,    16]
48 // wp_end_offset:       [     5,    8,  11,   16,  18]
49 // row_lengths:         [2,          1,  1]
50 struct WordpieceTokenizerResult
51     : tflite::support::text::tokenizer::TokenizerResult {
52   std::vector<int> wp_begin_offset;
53   std::vector<int> wp_end_offset;
54   std::vector<int> row_lengths;
55 };
56 
57 // Options to create a BertTokenizer.
58 struct BertTokenizerOptions {
59   int max_bytes_per_token = kDefaultMaxBytesPerToken;
60   int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
61   std::string suffix_indicator = kDefaultSuffixIndicator;
62   bool use_unknown_token = kDefaultUseUnknownToken;
63   std::string unknown_token = kDefaultUnknownToken;
64   bool split_unknown_chars = kDefaultSplitUnknownChars;
65 };
66 
67 // A flat-hash-map based implementation of WordpieceVocab, used in
68 // BertTokenizer to invoke tensorflow::text::WordpieceTokenize within.
69 class FlatHashMapBackedWordpiece : public WordpieceVocab {
70  public:
71   explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);
72 
73   LookupStatus Contains(absl::string_view key, bool* value) const override;
74   bool LookupId(absl::string_view key, int* result) const;
75   bool LookupWord(int vocab_id, absl::string_view* result) const;
VocabularySize()76   int VocabularySize() const { return vocab_.size(); }
77 
78  private:
79   // All words indexed position in vocabulary file.
80   std::vector<std::string> vocab_;
81   absl::flat_hash_map<absl::string_view, int> index_map_;
82 };
83 
84 // Wordpiece tokenizer for bert models. Initialized with a vocab file or vector.
85 //
86 // The full tokenization involves two steps: Splitting the input into tokens
87 // (pretokenization) and splitting the tokens into subwords.
88 class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
89  public:
90   // Initialize the tokenizer from vocab vector and tokenizer configs.
91   explicit BertTokenizer(const std::vector<std::string>& vocab,
92                          const BertTokenizerOptions& options = {})
93       : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}
94 
95   // Initialize the tokenizer from file path to vocab and tokenizer configs.
96   explicit BertTokenizer(const std::string& path_to_vocab,
97                          const BertTokenizerOptions& options = {})
BertTokenizer(LoadVocabFromFile (path_to_vocab),options)98       : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}
99 
100   // Initialize the tokenizer from buffer and size of vocab and tokenizer
101   // configs.
102   BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
103                 const BertTokenizerOptions& options = {})
BertTokenizer(LoadVocabFromBuffer (vocab_buffer_data,vocab_buffer_size),options)104       : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
105                       options) {}
106 
107   // Perform tokenization, first tokenize the input and then find the subwords.
108   // Return tokenized results containing the subwords.
109   TokenizerResult Tokenize(const std::string& input) override;
110 
111   // Perform tokenization, first tokenize the input and then find the subwords.
112   // Return tokenized results containing the subwords and codepoint indices.
113   WordpieceTokenizerResult TokenizeIntoWordpieces(const std::string& input);
114 
115   // Perform tokenization on a single token, return tokenized results containing
116   // the subwords and codepoint indices.
117   WordpieceTokenizerResult TokenizeSingleToken(const std::string& token);
118 
119   // Perform tokenization, return tokenized results containing the subwords and
120   // codepoint indices.
121   WordpieceTokenizerResult TokenizeIntoWordpieces(
122       const std::vector<Token>& tokens);
123 
124   // Check if a certain key is included in the vocab.
Contains(const absl::string_view key,bool * value)125   LookupStatus Contains(const absl::string_view key, bool* value) const {
126     return vocab_.Contains(key, value);
127   }
128 
129   // Find the id of a wordpiece.
LookupId(absl::string_view key,int * result)130   bool LookupId(absl::string_view key, int* result) const override {
131     return vocab_.LookupId(key, result);
132   }
133 
134   // Find the wordpiece from an id.
LookupWord(int vocab_id,absl::string_view * result)135   bool LookupWord(int vocab_id, absl::string_view* result) const override {
136     return vocab_.LookupWord(vocab_id, result);
137   }
138 
VocabularySize()139   int VocabularySize() const { return vocab_.VocabularySize(); }
140 
141   static std::vector<std::string> PreTokenize(const absl::string_view input);
142 
143  private:
144   FlatHashMapBackedWordpiece vocab_;
145   BertTokenizerOptions options_;
146 };
147 
148 }  // namespace libtextclassifier3
149 
150 #endif  // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
151