/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

/**
 * Word-level tokenizer driven by a flat JSON vocabulary file of the form
 *
 *   { "token": id, "other token": id, ... }
 *
 * The file is read with a small hand-written parser (not a general JSON
 * parser): keys are string literals, values are unsigned decimal integers.
 * Any structural error in the file prints a diagnostic to stderr and
 * terminates the process.
 */
class BasicTokenizer {
 public:
  /**
   * Loads the vocabulary stored at `file_path`.
   *
   * Terminates the process with exit code 9 if the file cannot be opened and
   * with exit code 1 if its content is not in the expected format.
   */
  explicit BasicTokenizer(const std::string& file_path) {
    std::ifstream file(file_path);

    if (!file) {
      std::cerr << "Unable to open file " << file_path << "\n";
      std::exit(9);
    }
    const std::string str(
        (std::istreambuf_iterator<char>(file)),
        std::istreambuf_iterator<char>());

    size_t i = 0u;
    i = consume_whitespace(str, i);
    i = expect(str, i, '{');
    // Skip whitespace up front so that an empty object ("{ }") is accepted.
    i = consume_whitespace(str, i);

    while (i < str.size() && str[i] != '}') {
      i = consume_field(str, i);
    }

    // Build decode map as inverse of encode.
    for (const auto& entry : encode_) {
      decode_[entry.second] = entry.first;
    }
  }

  /**
   * Splits `prompt` into words (see parse_prompt) and maps each word to its
   * vocabulary id.
   *
   * Words that are not in the vocabulary map to id 0. The lookup never
   * modifies the vocabulary.
   */
  std::vector<int64_t> encode(const std::string& prompt) const {
    const std::vector<std::string> words = parse_prompt(prompt);
    std::vector<int64_t> result;
    result.reserve(words.size());
    for (const auto& word : words) {
      const auto it = encode_.find(word);
      result.push_back(it != encode_.end() ? it->second : int64_t{0});
    }
    return result;
  }

  /**
   * Concatenates the vocabulary strings of `indices`, in order.
   *
   * Ids that are not in the vocabulary contribute nothing to the result.
   */
  std::string decode(const std::vector<int64_t>& indices) const {
    std::string result;
    for (const auto& index : indices) {
      const auto it = decode_.find(index);
      if (it != decode_.end()) {
        result += it->second;
      }
    }
    return result;
  }

 private:
  std::unordered_map<std::string, int64_t> encode_;
  std::unordered_map<int64_t, std::string> decode_;

  // Advance the input string index until a non-whitespace character is found
  // or it reaches the end of string.
  static size_t consume_whitespace(const std::string& data, size_t i) {
    // The <cctype> classifiers require a value representable as unsigned
    // char; a plain (possibly negative) char is undefined behaviour.
    while (i < data.size() &&
           std::isspace(static_cast<unsigned char>(data[i]))) {
      i++;
    }

    return i;
  }

  // Consumes a JSON field of the form
  //   "str": id,
  // starting at index i, records it in encode_, and returns the index just
  // past the field (including its trailing comma, if any).
  size_t consume_field(const std::string& data, size_t i) {
    i = consume_whitespace(data, i);

    // Parse the key literal. Escape sequences are copied verbatim here
    // (backslash included) and resolved afterwards by post_process_key.
    i = expect(data, i, '"');

    bool in_escape = false;
    std::string key;
    while (i < data.size()) {
      const char c = data[i++];
      if (in_escape) {
        // Character following a backslash: never terminates the literal.
        key += c;
        in_escape = false;
        continue;
      }
      if (c == '"') { // End of string literal
        break;
      }
      if (c == '\\') { // Escaped code point
        in_escape = true;
      }
      key += c;
    }

    key = post_process_key(std::move(key));

    // JSON permits whitespace on both sides of the name separator.
    i = consume_whitespace(data, i);
    i = expect(data, i, ':');
    i = consume_whitespace(data, i);

    // Read unsigned integer value
    const size_t value_start = i;
    while (i < data.size() &&
           std::isdigit(static_cast<unsigned char>(data[i]))) {
      i++;
    }
    if (i == value_start) {
      std::cerr << "Invalid tokenizer vocabulary file. Expected an unsigned "
                << "integer at index " << i << std::endl;
      std::exit(1);
    }

    int64_t value = 0;
    try {
      // stoll rather than stol: long is only 32 bits wide on some platforms.
      value = static_cast<int64_t>(
          std::stoll(data.substr(value_start, i - value_start)));
    } catch (const std::out_of_range&) {
      std::cerr << "Invalid tokenizer vocabulary file. Integer out of range "
                << "at index " << value_start << std::endl;
      std::exit(1);
    }

    encode_[key] = value;

    i = consume_whitespace(data, i);
    if (i < data.size() && data[i] == ',') {
      i++;
    }

    return i;
  }

  // Assert that the next character in the input string is equal to c. Returns
  // the input string index incremented by one; terminates the process with
  // exit code 1 if the character does not match or the input is exhausted.
  static size_t expect(const std::string& data, size_t i, char c) {
    if (i >= data.size() || data[i] != c) {
      std::cerr << "Invalid tokenizer vocabulary file. Expected '" << c
                << "' at index " << i << std::endl;
      std::exit(1);
    }

    return i + 1;
  }

  // Converts the raw (still escaped) text of a key literal into the string
  // stored in the vocabulary.
  static std::string post_process_key(std::string key) {
    // Replace the unicode characters with the corresponding byte encoding
    // TODO: adopt byte encoder to handle unicode characters in json file.
    static const std::pair<std::string, std::string> kReplacements[] = {
        {"\\u0120", " "},
        {"\\u010a", "\n"},
    };

    for (const auto& replacement : kReplacements) {
      size_t pos = 0;
      // Walk through all instances of the substring in the string.
      while ((pos = key.find(replacement.first, pos)) != std::string::npos) {
        key.replace(pos, replacement.first.length(), replacement.second);
        pos += replacement.second.length();
      }
    }

    // Resolve the remaining escapes by dropping each escaping backslash.
    // After the erase, the escaped character (possibly itself a backslash)
    // sits at idx; the loop increment steps over it so it is kept verbatim
    // and the character after it is examined normally.
    for (size_t idx = 0; idx < key.length(); idx++) {
      if (key[idx] == '\\') {
        key.erase(idx, 1);
      }
    }

    return key;
  }

  // Splits a prompt into vocabulary lookup units:
  //   - a space starts a new word and stays attached to the front of it,
  //   - every punctuation character is a word of its own,
  //   - all other characters extend the current word.
  static std::vector<std::string> parse_prompt(const std::string& prompt) {
    std::vector<std::string> result;
    std::string word;

    // Emits the word accumulated so far, if there is one.
    const auto flush = [&result, &word]() {
      if (!word.empty()) {
        result.push_back(word);
        word.clear();
      }
    };

    for (const char c : prompt) {
      if (c == ' ') {
        flush();
        word += c;
      } else if (std::ispunct(static_cast<unsigned char>(c))) {
        flush();
        result.emplace_back(1, c);
      } else {
        word += c;
      }
    }
    flush();

    return result;
  }
};