/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
#define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_

#include <fstream>
#include <string>
#include <vector>

#include "annotator/types.h"
#include "utils/wordpiece_tokenizer.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace libtextclassifier3 {

using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::support::utils::LoadVocabFromFile;

constexpr int kDefaultMaxBytesPerToken = 100;
constexpr int kDefaultMaxCharsPerSubToken = 100;
constexpr char kDefaultSuffixIndicator[] = "##";
constexpr bool kDefaultUseUnknownToken = true;
constexpr char kDefaultUnknownToken[] = "[UNK]";
constexpr bool kDefaultSplitUnknownChars = false;

// Result of wordpiece tokenization, including subwords and offsets.
// Example:
//   input:            tokenize me please
//   subwords:         token ##ize me plea ##se
//   wp_begin_offset:  [0, 5, 9, 12, 16]
//   wp_end_offset:    [5, 8, 11, 16, 18]
//   row_lengths:      [2, 1, 2]
struct WordpieceTokenizerResult
    : tflite::support::text::tokenizer::TokenizerResult {
  std::vector<int> wp_begin_offset;
  std::vector<int> wp_end_offset;
  std::vector<int> row_lengths;
};

// Options to create a BertTokenizer.
struct BertTokenizerOptions {
  int max_bytes_per_token = kDefaultMaxBytesPerToken;
  int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
  std::string suffix_indicator = kDefaultSuffixIndicator;
  bool use_unknown_token = kDefaultUseUnknownToken;
  std::string unknown_token = kDefaultUnknownToken;
  bool split_unknown_chars = kDefaultSplitUnknownChars;
};

// A flat-hash-map based implementation of WordpieceVocab, used in
// BertTokenizer to invoke tensorflow::text::WordpieceTokenize within.
class FlatHashMapBackedWordpiece : public WordpieceVocab {
 public:
  explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);

  LookupStatus Contains(absl::string_view key, bool* value) const override;
  bool LookupId(absl::string_view key, int* result) const;
  bool LookupWord(int vocab_id, absl::string_view* result) const;
  int VocabularySize() const { return vocab_.size(); }

 private:
  // All words, indexed by their position in the vocabulary file.
  std::vector<std::string> vocab_;
  absl::flat_hash_map<absl::string_view, int> index_map_;
};
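
// Example lookup round-trip (an illustrative sketch; the vocab entries and
// ids below are hypothetical, not taken from any real vocab file):
//
//   FlatHashMapBackedWordpiece vocab({"[UNK]", "token", "##ize"});
//   int id;
//   if (vocab.LookupId("##ize", &id)) {
//     // id == 2, since ids are positions in the vocab vector.
//   }
//   absl::string_view word;
//   if (vocab.LookupWord(/*vocab_id=*/2, &word)) {
//     // word == "##ize"; LookupWord is the inverse of LookupId.
//   }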

// Wordpiece tokenizer for BERT models. Initialized with a vocab file or
// vector.
//
// The full tokenization involves two steps: splitting the input into tokens
// (pretokenization) and splitting the tokens into subwords.
class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
 public:
  // Initializes the tokenizer from a vocab vector and tokenizer configs.
  explicit BertTokenizer(const std::vector<std::string>& vocab,
                         const BertTokenizerOptions& options = {})
      : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {}

  // Initializes the tokenizer from a file path to the vocab and tokenizer
  // configs.
  explicit BertTokenizer(const std::string& path_to_vocab,
                         const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {}

  // Initializes the tokenizer from a vocab buffer and its size, plus
  // tokenizer configs.
  BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
                const BertTokenizerOptions& options = {})
      : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
                      options) {}

  // Performs tokenization: first tokenizes the input, then finds the
  // subwords. Returns tokenized results containing the subwords.
  TokenizerResult Tokenize(const std::string& input) override;

  // Performs tokenization: first tokenizes the input, then finds the
  // subwords. Returns tokenized results containing the subwords and
  // codepoint indices.
  WordpieceTokenizerResult TokenizeIntoWordpieces(const std::string& input);

  // Performs tokenization on a single token. Returns tokenized results
  // containing the subwords and codepoint indices.
  WordpieceTokenizerResult TokenizeSingleToken(const std::string& token);

  // Performs tokenization on pretokenized input. Returns tokenized results
  // containing the subwords and codepoint indices.
  WordpieceTokenizerResult TokenizeIntoWordpieces(
      const std::vector<Token>& tokens);

  // Checks if a certain key is included in the vocab.
  LookupStatus Contains(const absl::string_view key, bool* value) const {
    return vocab_.Contains(key, value);
  }

  // Finds the id of a wordpiece.
  bool LookupId(absl::string_view key, int* result) const override {
    return vocab_.LookupId(key, result);
  }

  // Finds the wordpiece for an id.
  bool LookupWord(int vocab_id, absl::string_view* result) const override {
    return vocab_.LookupWord(vocab_id, result);
  }

  int VocabularySize() const { return vocab_.VocabularySize(); }

  // Splits the input into tokens (the pretokenization step).
  static std::vector<std::string> PreTokenize(const absl::string_view input);

 private:
  FlatHashMapBackedWordpiece vocab_;
  BertTokenizerOptions options_;
};

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_
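
// Example usage of BertTokenizer (an illustrative sketch; the vocab path is
// hypothetical, and the subword splits shown assume a vocab matching the
// wordpiece example above):
//
//   BertTokenizer tokenizer("/path/to/vocab.txt");
//   WordpieceTokenizerResult result =
//       tokenizer.TokenizeIntoWordpieces("tokenize me please");
//   // result.subwords holds {"token", "##ize", "me", "plea", "##se"};
//   // result.wp_begin_offset / result.wp_end_offset hold the codepoint
//   // offsets of each subword in the input string.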