1*993b0882SAndroid Build Coastguard Worker /* 2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project 3*993b0882SAndroid Build Coastguard Worker * 4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*993b0882SAndroid Build Coastguard Worker * 8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*993b0882SAndroid Build Coastguard Worker * 10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*993b0882SAndroid Build Coastguard Worker * limitations under the License. 15*993b0882SAndroid Build Coastguard Worker */ 16*993b0882SAndroid Build Coastguard Worker 17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_ 18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_ 19*993b0882SAndroid Build Coastguard Worker 20*993b0882SAndroid Build Coastguard Worker #include <fstream> 21*993b0882SAndroid Build Coastguard Worker #include <string> 22*993b0882SAndroid Build Coastguard Worker #include <vector> 23*993b0882SAndroid Build Coastguard Worker 24*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h" 25*993b0882SAndroid Build Coastguard Worker #include "utils/wordpiece_tokenizer.h" 26*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h" 27*993b0882SAndroid Build Coastguard Worker #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" 28*993b0882SAndroid Build Coastguard Worker #include "tensorflow_lite_support/cc/utils/common_utils.h" 29*993b0882SAndroid Build Coastguard Worker 30*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 31*993b0882SAndroid Build Coastguard Worker 32*993b0882SAndroid Build Coastguard Worker using ::tflite::support::text::tokenizer::TokenizerResult; 33*993b0882SAndroid Build Coastguard Worker using ::tflite::support::utils::LoadVocabFromBuffer; 34*993b0882SAndroid Build Coastguard Worker using ::tflite::support::utils::LoadVocabFromFile; 35*993b0882SAndroid Build Coastguard Worker 36*993b0882SAndroid Build Coastguard Worker constexpr int kDefaultMaxBytesPerToken = 100; 37*993b0882SAndroid Build Coastguard Worker constexpr int kDefaultMaxCharsPerSubToken = 100; 38*993b0882SAndroid Build Coastguard Worker constexpr char kDefaultSuffixIndicator[] = "##"; 39*993b0882SAndroid Build Coastguard Worker constexpr bool kDefaultUseUnknownToken = true; 40*993b0882SAndroid Build Coastguard Worker constexpr char kDefaultUnknownToken[] = "[UNK]"; 41*993b0882SAndroid Build Coastguard Worker constexpr bool kDefaultSplitUnknownChars = false; 42*993b0882SAndroid Build Coastguard Worker 43*993b0882SAndroid Build Coastguard Worker // Result of wordpiece tokenization including subwords and offsets. 44*993b0882SAndroid Build Coastguard Worker // Example: 45*993b0882SAndroid Build Coastguard Worker // input: tokenize me please 46*993b0882SAndroid Build Coastguard Worker // subwords: token ##ize me plea ##se 47*993b0882SAndroid Build Coastguard Worker // wp_begin_offset: [0, 5, 9, 12, 16] 48*993b0882SAndroid Build Coastguard Worker // wp_end_offset: [ 5, 8, 11, 16, 18] 49*993b0882SAndroid Build Coastguard Worker // row_lengths: [2, 1, 1] 50*993b0882SAndroid Build Coastguard Worker struct WordpieceTokenizerResult 51*993b0882SAndroid Build Coastguard Worker : tflite::support::text::tokenizer::TokenizerResult { 52*993b0882SAndroid Build Coastguard Worker std::vector<int> wp_begin_offset; 53*993b0882SAndroid Build Coastguard Worker std::vector<int> wp_end_offset; 54*993b0882SAndroid Build Coastguard Worker std::vector<int> row_lengths; 55*993b0882SAndroid Build Coastguard Worker }; 56*993b0882SAndroid Build Coastguard Worker 57*993b0882SAndroid Build Coastguard Worker // Options to create a BertTokenizer. 58*993b0882SAndroid Build Coastguard Worker struct BertTokenizerOptions { 59*993b0882SAndroid Build Coastguard Worker int max_bytes_per_token = kDefaultMaxBytesPerToken; 60*993b0882SAndroid Build Coastguard Worker int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken; 61*993b0882SAndroid Build Coastguard Worker std::string suffix_indicator = kDefaultSuffixIndicator; 62*993b0882SAndroid Build Coastguard Worker bool use_unknown_token = kDefaultUseUnknownToken; 63*993b0882SAndroid Build Coastguard Worker std::string unknown_token = kDefaultUnknownToken; 64*993b0882SAndroid Build Coastguard Worker bool split_unknown_chars = kDefaultSplitUnknownChars; 65*993b0882SAndroid Build Coastguard Worker }; 66*993b0882SAndroid Build Coastguard Worker 67*993b0882SAndroid Build Coastguard Worker // A flat-hash-map based implementation of WordpieceVocab, used in 68*993b0882SAndroid Build Coastguard Worker // BertTokenizer to invoke tensorflow::text::WordpieceTokenize within. 69*993b0882SAndroid Build Coastguard Worker class FlatHashMapBackedWordpiece : public WordpieceVocab { 70*993b0882SAndroid Build Coastguard Worker public: 71*993b0882SAndroid Build Coastguard Worker explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab); 72*993b0882SAndroid Build Coastguard Worker 73*993b0882SAndroid Build Coastguard Worker LookupStatus Contains(absl::string_view key, bool* value) const override; 74*993b0882SAndroid Build Coastguard Worker bool LookupId(absl::string_view key, int* result) const; 75*993b0882SAndroid Build Coastguard Worker bool LookupWord(int vocab_id, absl::string_view* result) const; VocabularySize()76*993b0882SAndroid Build Coastguard Worker int VocabularySize() const { return vocab_.size(); } 77*993b0882SAndroid Build Coastguard Worker 78*993b0882SAndroid Build Coastguard Worker private: 79*993b0882SAndroid Build Coastguard Worker // All words indexed position in vocabulary file. 80*993b0882SAndroid Build Coastguard Worker std::vector<std::string> vocab_; 81*993b0882SAndroid Build Coastguard Worker absl::flat_hash_map<absl::string_view, int> index_map_; 82*993b0882SAndroid Build Coastguard Worker }; 83*993b0882SAndroid Build Coastguard Worker 84*993b0882SAndroid Build Coastguard Worker // Wordpiece tokenizer for bert models. Initialized with a vocab file or vector. 85*993b0882SAndroid Build Coastguard Worker // 86*993b0882SAndroid Build Coastguard Worker // The full tokenization involves two steps: Splitting the input into tokens 87*993b0882SAndroid Build Coastguard Worker // (pretokenization) and splitting the tokens into subwords. 88*993b0882SAndroid Build Coastguard Worker class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer { 89*993b0882SAndroid Build Coastguard Worker public: 90*993b0882SAndroid Build Coastguard Worker // Initialize the tokenizer from vocab vector and tokenizer configs. 91*993b0882SAndroid Build Coastguard Worker explicit BertTokenizer(const std::vector<std::string>& vocab, 92*993b0882SAndroid Build Coastguard Worker const BertTokenizerOptions& options = {}) 93*993b0882SAndroid Build Coastguard Worker : vocab_{FlatHashMapBackedWordpiece(vocab)}, options_{options} {} 94*993b0882SAndroid Build Coastguard Worker 95*993b0882SAndroid Build Coastguard Worker // Initialize the tokenizer from file path to vocab and tokenizer configs. 96*993b0882SAndroid Build Coastguard Worker explicit BertTokenizer(const std::string& path_to_vocab, 97*993b0882SAndroid Build Coastguard Worker const BertTokenizerOptions& options = {}) BertTokenizer(LoadVocabFromFile (path_to_vocab),options)98*993b0882SAndroid Build Coastguard Worker : BertTokenizer(LoadVocabFromFile(path_to_vocab), options) {} 99*993b0882SAndroid Build Coastguard Worker 100*993b0882SAndroid Build Coastguard Worker // Initialize the tokenizer from buffer and size of vocab and tokenizer 101*993b0882SAndroid Build Coastguard Worker // configs. 102*993b0882SAndroid Build Coastguard Worker BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size, 103*993b0882SAndroid Build Coastguard Worker const BertTokenizerOptions& options = {}) BertTokenizer(LoadVocabFromBuffer (vocab_buffer_data,vocab_buffer_size),options)104*993b0882SAndroid Build Coastguard Worker : BertTokenizer(LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size), 105*993b0882SAndroid Build Coastguard Worker options) {} 106*993b0882SAndroid Build Coastguard Worker 107*993b0882SAndroid Build Coastguard Worker // Perform tokenization, first tokenize the input and then find the subwords. 108*993b0882SAndroid Build Coastguard Worker // Return tokenized results containing the subwords. 109*993b0882SAndroid Build Coastguard Worker TokenizerResult Tokenize(const std::string& input) override; 110*993b0882SAndroid Build Coastguard Worker 111*993b0882SAndroid Build Coastguard Worker // Perform tokenization, first tokenize the input and then find the subwords. 112*993b0882SAndroid Build Coastguard Worker // Return tokenized results containing the subwords and codepoint indices. 113*993b0882SAndroid Build Coastguard Worker WordpieceTokenizerResult TokenizeIntoWordpieces(const std::string& input); 114*993b0882SAndroid Build Coastguard Worker 115*993b0882SAndroid Build Coastguard Worker // Perform tokenization on a single token, return tokenized results containing 116*993b0882SAndroid Build Coastguard Worker // the subwords and codepoint indices. 117*993b0882SAndroid Build Coastguard Worker WordpieceTokenizerResult TokenizeSingleToken(const std::string& token); 118*993b0882SAndroid Build Coastguard Worker 119*993b0882SAndroid Build Coastguard Worker // Perform tokenization, return tokenized results containing the subwords and 120*993b0882SAndroid Build Coastguard Worker // codepoint indices. 121*993b0882SAndroid Build Coastguard Worker WordpieceTokenizerResult TokenizeIntoWordpieces( 122*993b0882SAndroid Build Coastguard Worker const std::vector<Token>& tokens); 123*993b0882SAndroid Build Coastguard Worker 124*993b0882SAndroid Build Coastguard Worker // Check if a certain key is included in the vocab. Contains(const absl::string_view key,bool * value)125*993b0882SAndroid Build Coastguard Worker LookupStatus Contains(const absl::string_view key, bool* value) const { 126*993b0882SAndroid Build Coastguard Worker return vocab_.Contains(key, value); 127*993b0882SAndroid Build Coastguard Worker } 128*993b0882SAndroid Build Coastguard Worker 129*993b0882SAndroid Build Coastguard Worker // Find the id of a wordpiece. LookupId(absl::string_view key,int * result)130*993b0882SAndroid Build Coastguard Worker bool LookupId(absl::string_view key, int* result) const override { 131*993b0882SAndroid Build Coastguard Worker return vocab_.LookupId(key, result); 132*993b0882SAndroid Build Coastguard Worker } 133*993b0882SAndroid Build Coastguard Worker 134*993b0882SAndroid Build Coastguard Worker // Find the wordpiece from an id. LookupWord(int vocab_id,absl::string_view * result)135*993b0882SAndroid Build Coastguard Worker bool LookupWord(int vocab_id, absl::string_view* result) const override { 136*993b0882SAndroid Build Coastguard Worker return vocab_.LookupWord(vocab_id, result); 137*993b0882SAndroid Build Coastguard Worker } 138*993b0882SAndroid Build Coastguard Worker VocabularySize()139*993b0882SAndroid Build Coastguard Worker int VocabularySize() const { return vocab_.VocabularySize(); } 140*993b0882SAndroid Build Coastguard Worker 141*993b0882SAndroid Build Coastguard Worker static std::vector<std::string> PreTokenize(const absl::string_view input); 142*993b0882SAndroid Build Coastguard Worker 143*993b0882SAndroid Build Coastguard Worker private: 144*993b0882SAndroid Build Coastguard Worker FlatHashMapBackedWordpiece vocab_; 145*993b0882SAndroid Build Coastguard Worker BertTokenizerOptions options_; 146*993b0882SAndroid Build Coastguard Worker }; 147*993b0882SAndroid Build Coastguard Worker 148*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3 149*993b0882SAndroid Build Coastguard Worker 150*993b0882SAndroid Build Coastguard Worker #endif // LIBTEXTCLASSIFIER_UTILS_BERT_TOKENIZER_H_ 151