/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cctype>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

class BasicTokenizer {
 public:
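  // Loads a vocabulary from a JSON file consisting of a single flat object
  // that maps token strings to integer ids, i.e. {"<token>": <id>, ...}.
  // Prints an error and exits if the file cannot be opened or parsed.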
  explicit BasicTokenizer(const std::string& file_path) {
    std::ifstream file(file_path);

    if (!file) {
      std::cerr << "Unable to open file " << file_path << "\n";
      exit(9);
    }
    std::string str(
        (std::istreambuf_iterator<char>(file)),
        std::istreambuf_iterator<char>());

    size_t i = 0u;
    i = consume_whitespace(str, i);
    i = expect(str, i, '{');

    while (i < str.size() && str[i] != '}') {
      i = consume_field(str, i);
    }

    // Build decode map as inverse of encode.
    for (const auto& entry : encode_) {
      decode_[entry.second] = entry.first;
    }
  }

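  // Splits the prompt into words and punctuation (see parse_prompt) and maps
  // each piece to its vocabulary id. Pieces missing from the vocabulary map
  // to id 0 via operator[] default insertion.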
  std::vector<int64_t> encode(const std::string& prompt) {
    std::vector<std::string> words = parse_prompt(prompt);
    std::vector<int64_t> result;
    for (const auto& word : words) {
      result.push_back(encode_[word]);
    }
    return result;
  }

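  // Maps each token id back to its vocabulary string and concatenates the
  // results. Ids missing from the vocabulary contribute an empty string.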
  std::string decode(const std::vector<int64_t>& indices) {
    std::string result;
    for (const auto& index : indices) {
      result += decode_[index];
    }
    return result;
  }

 private:
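  // Token-string-to-id map parsed from the vocabulary file, and its inverse.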
  std::unordered_map<std::string, int64_t> encode_;
  std::unordered_map<int64_t, std::string> decode_;

  // Advance the input string index until a non-whitespace character is found
  // or it reaches the end of the string.
  size_t consume_whitespace(const std::string& data, size_t i) {
    while (i < data.size() &&
           std::isspace(static_cast<unsigned char>(data[i]))) {
      i++;
    }

    return i;
  }

  // Consumes a JSON field of the form
  //  "str": id,
  size_t consume_field(const std::string& data, size_t i) {
    i = consume_whitespace(data, i);

    // Parse the key literal.
    i = expect(data, i, '"');

    auto in_escape = false;
    std::string key = "";
    while (i < data.size()) {
      if (in_escape) {
        key += data[i];
        i++;
        in_escape = false;
      } else { // !in_escape
        if (data[i] == '"') { // End of string literal
          i++;
          break;
        } else if (data[i] == '\\') { // Escaped code point
          in_escape = true;
        }
        key += data[i];
        i++;
      }
    }

    key = post_process_key(key);

    i = expect(data, i, ':');
    i = consume_whitespace(data, i);

    // Read the unsigned integer value.
    auto value_start = i;
    while (i < data.size() &&
           std::isdigit(static_cast<unsigned char>(data[i]))) {
      i++;
    }
    auto value = static_cast<int64_t>(
        std::stol(data.substr(value_start, i - value_start)));

    encode_[key] = value;

    i = consume_whitespace(data, i);
    if (i < data.size() && data[i] == ',') {
      i++;
    }

    return i;
  }

  // Check that the next character in the input string is equal to c, printing
  // an error and exiting if it is not. Returns the index advanced past c.
  size_t expect(const std::string& data, size_t i, char c) {
    if (i >= data.size() || data[i] != c) {
      std::cerr << "Invalid tokenizer vocabulary file. Expected '" << c
                << "' at index " << i << std::endl;
      exit(1);
    }

    return i + 1;
  }

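  // Converts a raw JSON key into the literal token text: rewrites the
  // byte-level markers \u0120 (space) and \u010a (newline), then strips the
  // leftover escape backslashes, keeping one backslash from each escaped pair.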
  std::string post_process_key(std::string key) {
    // Replace the unicode escape sequences with their corresponding byte
    // encoding.
    // TODO: adopt byte encoder to handle unicode characters in json file.

    std::unordered_map<std::string, std::string> replacements = {
        {"\\u0120", " "},
        {"\\u010a", "\n"},
    };

    for (const auto& replacement : replacements) {
      size_t pos = 0;
      // Loop over all instances of the substring in the key.
      while ((pos = key.find(replacement.first, pos)) != std::string::npos) {
        key.replace(pos, replacement.first.length(), replacement.second);
        pos += replacement.second.length();
      }
    }

    // Remove duplicate backslashes.
    for (size_t idx = 0; idx < key.length(); idx++) {
      if (key[idx] == '\\') {
        key.erase(idx, 1);
        if (key[idx] == '\\') {
          // If there are two backslashes, keep the second one.
          idx += 1;
        }
      }
    }

    return key;
  }
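
  // Splits the prompt into pieces for encoding: punctuation characters become
  // single-character pieces, and words keep the leading space that preceded
  // them, mirroring how the vocabulary stores its tokens.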
  std::vector<std::string> parse_prompt(const std::string& prompt) {
    std::vector<std::string> result;
    std::string word;
    for (char c : prompt) {
      if (c == ' ') {
        if (!word.empty()) {
          result.push_back(word);
          word.clear();
        }
        word += c;
      } else if (std::ispunct(static_cast<unsigned char>(c))) {
        if (!word.empty()) {
          result.push_back(word);
          word.clear();
        }
        result.push_back(std::string(1, c));
      } else {
        word += c;
      }
    }
    if (!word.empty()) {
      result.push_back(word);
    }
    return result;
  }
};
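
// A minimal usage sketch (the file name and prompt below are illustrative,
// not part of this example):
//
//   BasicTokenizer tokenizer("vocab.json");
//   std::vector<int64_t> tokens = tokenizer.encode("Hello world!");
//   std::string round_trip = tokenizer.decode(tokens);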