1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker // Adopted from https://github.com/sewenew/tokenizer
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore-every LICENSELINT
12*523fa7a6SAndroid Build Coastguard Worker /**************************************************************************
13*523fa7a6SAndroid Build Coastguard Worker Copyright (c) 2023 sewenew
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker Licensed under the Apache License, Version 2.0 (the "License");
16*523fa7a6SAndroid Build Coastguard Worker you may not use this file except in compliance with the License.
17*523fa7a6SAndroid Build Coastguard Worker You may obtain a copy of the License at
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Worker http://www.apache.org/licenses/LICENSE-2.0
20*523fa7a6SAndroid Build Coastguard Worker
21*523fa7a6SAndroid Build Coastguard Worker Unless required by applicable law or agreed to in writing, software
22*523fa7a6SAndroid Build Coastguard Worker distributed under the License is distributed on an "AS IS" BASIS,
23*523fa7a6SAndroid Build Coastguard Worker WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24*523fa7a6SAndroid Build Coastguard Worker See the License for the specific language governing permissions and
25*523fa7a6SAndroid Build Coastguard Worker limitations under the License.
26*523fa7a6SAndroid Build Coastguard Worker *************************************************************************/
27*523fa7a6SAndroid Build Coastguard Worker
28*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/base64.h>
29*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/tiktoken.h>
30*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/result.h>
31*523fa7a6SAndroid Build Coastguard Worker #include <fstream>
32*523fa7a6SAndroid Build Coastguard Worker #include <limits>
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Error;
35*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Result;
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
38*523fa7a6SAndroid Build Coastguard Worker namespace extension {
39*523fa7a6SAndroid Build Coastguard Worker namespace llm {
40*523fa7a6SAndroid Build Coastguard Worker
41*523fa7a6SAndroid Build Coastguard Worker // ------------------------------Util start------------------------------------
42*523fa7a6SAndroid Build Coastguard Worker
// Sentinel rank used by the BPE merge loop below: the largest representable
// uint64_t, which is never accepted as a real token rank.
static uint64_t _max_size() {
  constexpr uint64_t kLargest = std::numeric_limits<uint64_t>::max();
  return kLargest;
}
46*523fa7a6SAndroid Build Coastguard Worker
_create_regex(const std::string & pattern)47*523fa7a6SAndroid Build Coastguard Worker static Re2UPtr _create_regex(const std::string& pattern) {
48*523fa7a6SAndroid Build Coastguard Worker assert(!pattern.empty());
49*523fa7a6SAndroid Build Coastguard Worker
50*523fa7a6SAndroid Build Coastguard Worker return std::make_unique<re2::RE2>("(" + pattern + ")");
51*523fa7a6SAndroid Build Coastguard Worker }
52*523fa7a6SAndroid Build Coastguard Worker
_build_special_token_regex(const Encoder & special_encoder)53*523fa7a6SAndroid Build Coastguard Worker static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
54*523fa7a6SAndroid Build Coastguard Worker std::string special_pattern;
55*523fa7a6SAndroid Build Coastguard Worker for (const auto& ele : special_encoder) {
56*523fa7a6SAndroid Build Coastguard Worker if (!special_pattern.empty()) {
57*523fa7a6SAndroid Build Coastguard Worker special_pattern += "|";
58*523fa7a6SAndroid Build Coastguard Worker }
59*523fa7a6SAndroid Build Coastguard Worker special_pattern += re2::RE2::QuoteMeta(ele.first);
60*523fa7a6SAndroid Build Coastguard Worker }
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker if (special_pattern.empty()) {
63*523fa7a6SAndroid Build Coastguard Worker return nullptr;
64*523fa7a6SAndroid Build Coastguard Worker }
65*523fa7a6SAndroid Build Coastguard Worker
66*523fa7a6SAndroid Build Coastguard Worker return _create_regex(special_pattern);
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker
_parse(const std::string & line)69*523fa7a6SAndroid Build Coastguard Worker static Result<std::pair<std::string, uint64_t>> _parse(
70*523fa7a6SAndroid Build Coastguard Worker const std::string& line) {
71*523fa7a6SAndroid Build Coastguard Worker // Tiktoken format
72*523fa7a6SAndroid Build Coastguard Worker // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 <base64
73*523fa7a6SAndroid Build Coastguard Worker // encoded token str> <rank>
74*523fa7a6SAndroid Build Coastguard Worker auto pos = line.find(" ");
75*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
76*523fa7a6SAndroid Build Coastguard Worker pos != std::string::npos,
77*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
78*523fa7a6SAndroid Build Coastguard Worker "invalid tiktoken line: %s",
79*523fa7a6SAndroid Build Coastguard Worker line.c_str());
80*523fa7a6SAndroid Build Coastguard Worker
81*523fa7a6SAndroid Build Coastguard Worker auto token = ET_UNWRAP(base64::decode({line.data(), pos}));
82*523fa7a6SAndroid Build Coastguard Worker uint64_t rank = 0;
83*523fa7a6SAndroid Build Coastguard Worker try {
84*523fa7a6SAndroid Build Coastguard Worker rank = std::stoul(line.substr(pos + 1));
85*523fa7a6SAndroid Build Coastguard Worker } catch (const std::exception&) {
86*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
87*523fa7a6SAndroid Build Coastguard Worker false, InvalidArgument, "invalid encoder rank: %s", line.c_str());
88*523fa7a6SAndroid Build Coastguard Worker }
89*523fa7a6SAndroid Build Coastguard Worker
90*523fa7a6SAndroid Build Coastguard Worker return std::pair{std::move(token), rank};
91*523fa7a6SAndroid Build Coastguard Worker }
92*523fa7a6SAndroid Build Coastguard Worker
_load_encoder(const std::string & path)93*523fa7a6SAndroid Build Coastguard Worker static Result<Encoder> _load_encoder(const std::string& path) {
94*523fa7a6SAndroid Build Coastguard Worker std::ifstream file(path);
95*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
96*523fa7a6SAndroid Build Coastguard Worker file, InvalidArgument, "failed to open encoder file: %s", path.c_str());
97*523fa7a6SAndroid Build Coastguard Worker
98*523fa7a6SAndroid Build Coastguard Worker Encoder encoder;
99*523fa7a6SAndroid Build Coastguard Worker std::string line;
100*523fa7a6SAndroid Build Coastguard Worker while (std::getline(file, line)) {
101*523fa7a6SAndroid Build Coastguard Worker auto [token, rank] = ET_UNWRAP(_parse(line));
102*523fa7a6SAndroid Build Coastguard Worker
103*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
104*523fa7a6SAndroid Build Coastguard Worker encoder.emplace(std::move(token), rank).second,
105*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
106*523fa7a6SAndroid Build Coastguard Worker "duplicate item: %s",
107*523fa7a6SAndroid Build Coastguard Worker line.c_str());
108*523fa7a6SAndroid Build Coastguard Worker }
109*523fa7a6SAndroid Build Coastguard Worker
110*523fa7a6SAndroid Build Coastguard Worker return encoder;
111*523fa7a6SAndroid Build Coastguard Worker }
112*523fa7a6SAndroid Build Coastguard Worker
_build_decoder(const Encoder & encoder)113*523fa7a6SAndroid Build Coastguard Worker static Result<Decoder> _build_decoder(const Encoder& encoder) {
114*523fa7a6SAndroid Build Coastguard Worker Decoder decoder;
115*523fa7a6SAndroid Build Coastguard Worker for (const auto& [k, v] : encoder) {
116*523fa7a6SAndroid Build Coastguard Worker decoder.emplace(v, k);
117*523fa7a6SAndroid Build Coastguard Worker }
118*523fa7a6SAndroid Build Coastguard Worker
119*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OR_RETURN_ERROR(
120*523fa7a6SAndroid Build Coastguard Worker encoder.size() == decoder.size(),
121*523fa7a6SAndroid Build Coastguard Worker InvalidArgument,
122*523fa7a6SAndroid Build Coastguard Worker "duplicate items in encoder");
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker return decoder;
125*523fa7a6SAndroid Build Coastguard Worker }
126*523fa7a6SAndroid Build Coastguard Worker
_byte_pair_merge(const std::string & piece,const std::unordered_map<std::string,uint64_t> & ranks,std::function<uint64_t (uint64_t,uint64_t)> func)127*523fa7a6SAndroid Build Coastguard Worker static std::vector<uint64_t> _byte_pair_merge(
128*523fa7a6SAndroid Build Coastguard Worker const std::string& piece,
129*523fa7a6SAndroid Build Coastguard Worker const std::unordered_map<std::string, uint64_t>& ranks,
130*523fa7a6SAndroid Build Coastguard Worker std::function<uint64_t(uint64_t, uint64_t)> func) {
131*523fa7a6SAndroid Build Coastguard Worker // This is a vector of (start, rank).
132*523fa7a6SAndroid Build Coastguard Worker // The rank is of the byte pair starting at position start.
133*523fa7a6SAndroid Build Coastguard Worker // The rank of the last item in the vector is not a valid value.
134*523fa7a6SAndroid Build Coastguard Worker std::vector<std::pair<uint64_t, uint64_t>> parts;
135*523fa7a6SAndroid Build Coastguard Worker parts.reserve(piece.size() + 1);
136*523fa7a6SAndroid Build Coastguard Worker for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
137*523fa7a6SAndroid Build Coastguard Worker parts.emplace_back(idx, _max_size());
138*523fa7a6SAndroid Build Coastguard Worker }
139*523fa7a6SAndroid Build Coastguard Worker
140*523fa7a6SAndroid Build Coastguard Worker auto get_rank = [&piece, &ranks](
141*523fa7a6SAndroid Build Coastguard Worker const std::vector<std::pair<uint64_t, uint64_t>>& parts,
142*523fa7a6SAndroid Build Coastguard Worker uint64_t start_idx,
143*523fa7a6SAndroid Build Coastguard Worker uint64_t skip) -> std::optional<uint64_t> {
144*523fa7a6SAndroid Build Coastguard Worker if (start_idx + skip + 2 < parts.size()) {
145*523fa7a6SAndroid Build Coastguard Worker auto s = parts[start_idx].first;
146*523fa7a6SAndroid Build Coastguard Worker auto e = parts[start_idx + skip + 2].first;
147*523fa7a6SAndroid Build Coastguard Worker auto key = piece.substr(s, e - s);
148*523fa7a6SAndroid Build Coastguard Worker auto iter = ranks.find(key);
149*523fa7a6SAndroid Build Coastguard Worker if (iter != ranks.end()) {
150*523fa7a6SAndroid Build Coastguard Worker return iter->second;
151*523fa7a6SAndroid Build Coastguard Worker }
152*523fa7a6SAndroid Build Coastguard Worker }
153*523fa7a6SAndroid Build Coastguard Worker return std::nullopt;
154*523fa7a6SAndroid Build Coastguard Worker };
155*523fa7a6SAndroid Build Coastguard Worker
156*523fa7a6SAndroid Build Coastguard Worker // We look up the ranks once in the beginning and iteratively update
157*523fa7a6SAndroid Build Coastguard Worker // them during each merge, which reduces the number of rank lookups.
158*523fa7a6SAndroid Build Coastguard Worker for (auto i = 0U; i < parts.size() - 2; ++i) {
159*523fa7a6SAndroid Build Coastguard Worker auto rank = get_rank(parts, i, 0);
160*523fa7a6SAndroid Build Coastguard Worker if (rank) {
161*523fa7a6SAndroid Build Coastguard Worker // usize::MAX is a sentinel value and cannot be a valid rank
162*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(*rank != _max_size(), "rank is too large");
163*523fa7a6SAndroid Build Coastguard Worker parts[i].second = *rank;
164*523fa7a6SAndroid Build Coastguard Worker }
165*523fa7a6SAndroid Build Coastguard Worker }
166*523fa7a6SAndroid Build Coastguard Worker
167*523fa7a6SAndroid Build Coastguard Worker // If you have n parts and m merges, this does O(mn) work.
168*523fa7a6SAndroid Build Coastguard Worker // We could do something with a heap and do O(m log n) work.
169*523fa7a6SAndroid Build Coastguard Worker // It is important to consider that n is often small (<100), and as such
170*523fa7a6SAndroid Build Coastguard Worker // the cache-locality benefits outweigh the algorithmic complexity downsides
171*523fa7a6SAndroid Build Coastguard Worker // of the `parts` vector data structure above.
172*523fa7a6SAndroid Build Coastguard Worker
173*523fa7a6SAndroid Build Coastguard Worker // Note that we hash bytes, not token pairs. As long as we train BPE the way
174*523fa7a6SAndroid Build Coastguard Worker // we currently do, this is equivalent. An easy way to break this would be
175*523fa7a6SAndroid Build Coastguard Worker // to decouple merge priority from token index or to prevent specific token
176*523fa7a6SAndroid Build Coastguard Worker // merges.
177*523fa7a6SAndroid Build Coastguard Worker while (true) {
178*523fa7a6SAndroid Build Coastguard Worker if (parts.size() == 1) {
179*523fa7a6SAndroid Build Coastguard Worker break;
180*523fa7a6SAndroid Build Coastguard Worker }
181*523fa7a6SAndroid Build Coastguard Worker
182*523fa7a6SAndroid Build Coastguard Worker // usize::MAX is a sentinel rank value allowing us to
183*523fa7a6SAndroid Build Coastguard Worker // take the min more quickly
184*523fa7a6SAndroid Build Coastguard Worker auto min_rank = std::make_pair<uint64_t, uint64_t>(_max_size(), 0);
185*523fa7a6SAndroid Build Coastguard Worker for (auto i = 0U; i < parts.size() - 1; ++i) {
186*523fa7a6SAndroid Build Coastguard Worker auto rank = parts[i].second;
187*523fa7a6SAndroid Build Coastguard Worker if (rank < min_rank.first) {
188*523fa7a6SAndroid Build Coastguard Worker min_rank.first = rank;
189*523fa7a6SAndroid Build Coastguard Worker min_rank.second = i;
190*523fa7a6SAndroid Build Coastguard Worker }
191*523fa7a6SAndroid Build Coastguard Worker }
192*523fa7a6SAndroid Build Coastguard Worker
193*523fa7a6SAndroid Build Coastguard Worker if (min_rank.first != _max_size()) {
194*523fa7a6SAndroid Build Coastguard Worker auto i = min_rank.second;
195*523fa7a6SAndroid Build Coastguard Worker
196*523fa7a6SAndroid Build Coastguard Worker // NOTE: We are about to remove parts[i + 1]. We do not do it
197*523fa7a6SAndroid Build Coastguard Worker // yet because there are cache-locality benefits to updating
198*523fa7a6SAndroid Build Coastguard Worker // parts[i] and parts[i-1] before removing, which could thrash
199*523fa7a6SAndroid Build Coastguard Worker // the cache. Thus, we update the rank calculation by skipping over
200*523fa7a6SAndroid Build Coastguard Worker // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
201*523fa7a6SAndroid Build Coastguard Worker auto rank = get_rank(parts, i, 1);
202*523fa7a6SAndroid Build Coastguard Worker if (rank) {
203*523fa7a6SAndroid Build Coastguard Worker parts[i].second = *rank;
204*523fa7a6SAndroid Build Coastguard Worker } else {
205*523fa7a6SAndroid Build Coastguard Worker parts[i].second = _max_size();
206*523fa7a6SAndroid Build Coastguard Worker }
207*523fa7a6SAndroid Build Coastguard Worker if (i > 0) {
208*523fa7a6SAndroid Build Coastguard Worker rank = get_rank(parts, i - 1, 1);
209*523fa7a6SAndroid Build Coastguard Worker if (rank) {
210*523fa7a6SAndroid Build Coastguard Worker parts[i - 1].second = *rank;
211*523fa7a6SAndroid Build Coastguard Worker } else {
212*523fa7a6SAndroid Build Coastguard Worker parts[i - 1].second = _max_size();
213*523fa7a6SAndroid Build Coastguard Worker }
214*523fa7a6SAndroid Build Coastguard Worker }
215*523fa7a6SAndroid Build Coastguard Worker
216*523fa7a6SAndroid Build Coastguard Worker parts.erase(parts.begin() + (i + 1));
217*523fa7a6SAndroid Build Coastguard Worker } else {
218*523fa7a6SAndroid Build Coastguard Worker break;
219*523fa7a6SAndroid Build Coastguard Worker }
220*523fa7a6SAndroid Build Coastguard Worker }
221*523fa7a6SAndroid Build Coastguard Worker std::vector<uint64_t> out;
222*523fa7a6SAndroid Build Coastguard Worker out.reserve(parts.size() - 1);
223*523fa7a6SAndroid Build Coastguard Worker for (auto i = 0U; i < parts.size() - 1; ++i) {
224*523fa7a6SAndroid Build Coastguard Worker auto s = parts[i].first;
225*523fa7a6SAndroid Build Coastguard Worker auto e = parts[i + 1].first;
226*523fa7a6SAndroid Build Coastguard Worker out.push_back(func(s, e));
227*523fa7a6SAndroid Build Coastguard Worker }
228*523fa7a6SAndroid Build Coastguard Worker return out;
229*523fa7a6SAndroid Build Coastguard Worker }
230*523fa7a6SAndroid Build Coastguard Worker
_byte_pair_encode(const std::string & piece,const Encoder & encoder)231*523fa7a6SAndroid Build Coastguard Worker static std::vector<uint64_t> _byte_pair_encode(
232*523fa7a6SAndroid Build Coastguard Worker const std::string& piece,
233*523fa7a6SAndroid Build Coastguard Worker const Encoder& encoder) {
234*523fa7a6SAndroid Build Coastguard Worker if (piece.size() == 1) {
235*523fa7a6SAndroid Build Coastguard Worker auto iter = encoder.find(piece);
236*523fa7a6SAndroid Build Coastguard Worker if (iter != encoder.end()) {
237*523fa7a6SAndroid Build Coastguard Worker return std::vector<uint64_t>({iter->second});
238*523fa7a6SAndroid Build Coastguard Worker } else {
239*523fa7a6SAndroid Build Coastguard Worker // TODO: is it possible?
240*523fa7a6SAndroid Build Coastguard Worker return {};
241*523fa7a6SAndroid Build Coastguard Worker }
242*523fa7a6SAndroid Build Coastguard Worker }
243*523fa7a6SAndroid Build Coastguard Worker
244*523fa7a6SAndroid Build Coastguard Worker return _byte_pair_merge(
245*523fa7a6SAndroid Build Coastguard Worker piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) {
246*523fa7a6SAndroid Build Coastguard Worker std::string key = piece.substr(start, stop - start);
247*523fa7a6SAndroid Build Coastguard Worker auto iter = encoder.find(key);
248*523fa7a6SAndroid Build Coastguard Worker if (iter != encoder.end()) {
249*523fa7a6SAndroid Build Coastguard Worker return iter->second;
250*523fa7a6SAndroid Build Coastguard Worker } else {
251*523fa7a6SAndroid Build Coastguard Worker // TODO: what if key does not exist? Should we return `unknown`?
252*523fa7a6SAndroid Build Coastguard Worker // assert(false); // ??
253*523fa7a6SAndroid Build Coastguard Worker return uint64_t(0);
254*523fa7a6SAndroid Build Coastguard Worker }
255*523fa7a6SAndroid Build Coastguard Worker });
256*523fa7a6SAndroid Build Coastguard Worker }
257*523fa7a6SAndroid Build Coastguard Worker // ------------------------------Util end------------------------------------
258*523fa7a6SAndroid Build Coastguard Worker // -------------------------private method start-------------------------------
259*523fa7a6SAndroid Build Coastguard Worker
260*523fa7a6SAndroid Build Coastguard Worker template <typename T>
261*523fa7a6SAndroid Build Coastguard Worker std::pair<std::optional<std::string>, re2::StringPiece>
_split_with_allowed_special_token(re2::StringPiece & input,const T & allowed_special) const262*523fa7a6SAndroid Build Coastguard Worker Tiktoken::_split_with_allowed_special_token(
263*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece& input,
264*523fa7a6SAndroid Build Coastguard Worker const T& allowed_special) const {
265*523fa7a6SAndroid Build Coastguard Worker if (!_special_token_regex) {
266*523fa7a6SAndroid Build Coastguard Worker return std::make_pair(std::nullopt, input);
267*523fa7a6SAndroid Build Coastguard Worker }
268*523fa7a6SAndroid Build Coastguard Worker
269*523fa7a6SAndroid Build Coastguard Worker #if __cplusplus >= 202002L
270*523fa7a6SAndroid Build Coastguard Worker auto start = input.begin();
271*523fa7a6SAndroid Build Coastguard Worker #else
272*523fa7a6SAndroid Build Coastguard Worker const char* start = input.data();
273*523fa7a6SAndroid Build Coastguard Worker #endif
274*523fa7a6SAndroid Build Coastguard Worker std::string special;
275*523fa7a6SAndroid Build Coastguard Worker while (true) {
276*523fa7a6SAndroid Build Coastguard Worker if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
277*523fa7a6SAndroid Build Coastguard Worker // No special token.
278*523fa7a6SAndroid Build Coastguard Worker break;
279*523fa7a6SAndroid Build Coastguard Worker }
280*523fa7a6SAndroid Build Coastguard Worker
281*523fa7a6SAndroid Build Coastguard Worker if (allowed_special.count(special) == 1) {
282*523fa7a6SAndroid Build Coastguard Worker // Found an allowed special token, split the text with it.
283*523fa7a6SAndroid Build Coastguard Worker #if __cplusplus >= 202002L
284*523fa7a6SAndroid Build Coastguard Worker return std::make_pair(
285*523fa7a6SAndroid Build Coastguard Worker special,
286*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece(start, input.begin() - start - special.size()));
287*523fa7a6SAndroid Build Coastguard Worker #else
288*523fa7a6SAndroid Build Coastguard Worker return std::make_pair(
289*523fa7a6SAndroid Build Coastguard Worker special,
290*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece(start, (input.data() - start) - special.size()));
291*523fa7a6SAndroid Build Coastguard Worker #endif
292*523fa7a6SAndroid Build Coastguard Worker } // else try to find the next special token
293*523fa7a6SAndroid Build Coastguard Worker }
294*523fa7a6SAndroid Build Coastguard Worker
295*523fa7a6SAndroid Build Coastguard Worker return std::make_pair(std::nullopt, input);
296*523fa7a6SAndroid Build Coastguard Worker }
297*523fa7a6SAndroid Build Coastguard Worker
_encode(re2::StringPiece & input,std::vector<uint64_t> & ret,uint64_t & last_piece_token_len) const298*523fa7a6SAndroid Build Coastguard Worker void Tiktoken::_encode(
299*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece& input,
300*523fa7a6SAndroid Build Coastguard Worker std::vector<uint64_t>& ret,
301*523fa7a6SAndroid Build Coastguard Worker uint64_t& last_piece_token_len) const {
302*523fa7a6SAndroid Build Coastguard Worker std::string piece;
303*523fa7a6SAndroid Build Coastguard Worker assert(_regex);
304*523fa7a6SAndroid Build Coastguard Worker while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
305*523fa7a6SAndroid Build Coastguard Worker auto iter = _encoder.find(piece);
306*523fa7a6SAndroid Build Coastguard Worker if (iter != _encoder.end()) {
307*523fa7a6SAndroid Build Coastguard Worker last_piece_token_len = 1;
308*523fa7a6SAndroid Build Coastguard Worker ret.push_back(iter->second);
309*523fa7a6SAndroid Build Coastguard Worker continue;
310*523fa7a6SAndroid Build Coastguard Worker }
311*523fa7a6SAndroid Build Coastguard Worker auto tokens = _byte_pair_encode(piece, _encoder);
312*523fa7a6SAndroid Build Coastguard Worker last_piece_token_len = tokens.size();
313*523fa7a6SAndroid Build Coastguard Worker ret.insert(ret.end(), tokens.begin(), tokens.end());
314*523fa7a6SAndroid Build Coastguard Worker }
315*523fa7a6SAndroid Build Coastguard Worker }
316*523fa7a6SAndroid Build Coastguard Worker
317*523fa7a6SAndroid Build Coastguard Worker template <typename T>
_encode_with_special_token(const std::string & text,const T & allowed_special) const318*523fa7a6SAndroid Build Coastguard Worker std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
319*523fa7a6SAndroid Build Coastguard Worker const std::string& text,
320*523fa7a6SAndroid Build Coastguard Worker const T& allowed_special) const {
321*523fa7a6SAndroid Build Coastguard Worker std::vector<uint64_t> tokens;
322*523fa7a6SAndroid Build Coastguard Worker uint64_t last_piece_token_len = 0;
323*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece input(text);
324*523fa7a6SAndroid Build Coastguard Worker while (true) {
325*523fa7a6SAndroid Build Coastguard Worker auto [special, sub_input] =
326*523fa7a6SAndroid Build Coastguard Worker _split_with_allowed_special_token(input, allowed_special);
327*523fa7a6SAndroid Build Coastguard Worker
328*523fa7a6SAndroid Build Coastguard Worker _encode(sub_input, tokens, last_piece_token_len);
329*523fa7a6SAndroid Build Coastguard Worker
330*523fa7a6SAndroid Build Coastguard Worker if (special) {
331*523fa7a6SAndroid Build Coastguard Worker uint64_t token = 0;
332*523fa7a6SAndroid Build Coastguard Worker try {
333*523fa7a6SAndroid Build Coastguard Worker token = _special_token_encoder.at(*special);
334*523fa7a6SAndroid Build Coastguard Worker } catch (const std::out_of_range&) {
335*523fa7a6SAndroid Build Coastguard Worker // Should never go here, since special pattern includes all special
336*523fa7a6SAndroid Build Coastguard Worker // chars.
337*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(false, "unknown special token: %s", special->c_str());
338*523fa7a6SAndroid Build Coastguard Worker }
339*523fa7a6SAndroid Build Coastguard Worker
340*523fa7a6SAndroid Build Coastguard Worker tokens.push_back(token);
341*523fa7a6SAndroid Build Coastguard Worker last_piece_token_len = 0;
342*523fa7a6SAndroid Build Coastguard Worker } else {
343*523fa7a6SAndroid Build Coastguard Worker break;
344*523fa7a6SAndroid Build Coastguard Worker }
345*523fa7a6SAndroid Build Coastguard Worker }
346*523fa7a6SAndroid Build Coastguard Worker
347*523fa7a6SAndroid Build Coastguard Worker // last_piece_token_len is how many tokens came from the last regex split.
348*523fa7a6SAndroid Build Coastguard Worker // This is used for determining unstable tokens, since you can't merge
349*523fa7a6SAndroid Build Coastguard Worker // across (stable) regex splits
350*523fa7a6SAndroid Build Coastguard Worker return std::make_pair(tokens, last_piece_token_len);
351*523fa7a6SAndroid Build Coastguard Worker }
352*523fa7a6SAndroid Build Coastguard Worker
_build_special_token_encoder(ssize_t num_base_tokens) const353*523fa7a6SAndroid Build Coastguard Worker Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
354*523fa7a6SAndroid Build Coastguard Worker Encoder special_token_encoder;
355*523fa7a6SAndroid Build Coastguard Worker for (ssize_t i = 0; i < _special_tokens->size(); ++i) {
356*523fa7a6SAndroid Build Coastguard Worker special_token_encoder.emplace(_special_tokens->at(i), num_base_tokens + i);
357*523fa7a6SAndroid Build Coastguard Worker }
358*523fa7a6SAndroid Build Coastguard Worker return special_token_encoder;
359*523fa7a6SAndroid Build Coastguard Worker }
360*523fa7a6SAndroid Build Coastguard Worker
361*523fa7a6SAndroid Build Coastguard Worker // -------------------------private method end-------------------------------
362*523fa7a6SAndroid Build Coastguard Worker // -------------------------public method start-------------------------------
363*523fa7a6SAndroid Build Coastguard Worker
Tiktoken(std::unique_ptr<std::vector<std::string>> special_tokens,size_t bos_token_index,size_t eos_token_index)364*523fa7a6SAndroid Build Coastguard Worker Tiktoken::Tiktoken(
365*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<std::vector<std::string>> special_tokens,
366*523fa7a6SAndroid Build Coastguard Worker size_t bos_token_index,
367*523fa7a6SAndroid Build Coastguard Worker size_t eos_token_index)
368*523fa7a6SAndroid Build Coastguard Worker : Tokenizer(),
369*523fa7a6SAndroid Build Coastguard Worker _special_tokens(std::move(special_tokens)),
370*523fa7a6SAndroid Build Coastguard Worker _bos_token_index(bos_token_index),
371*523fa7a6SAndroid Build Coastguard Worker _eos_token_index(eos_token_index) {
372*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
373*523fa7a6SAndroid Build Coastguard Worker _bos_token_index < _special_tokens->size(),
374*523fa7a6SAndroid Build Coastguard Worker "invalid bos_token_index %zu",
375*523fa7a6SAndroid Build Coastguard Worker _bos_token_index);
376*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(
377*523fa7a6SAndroid Build Coastguard Worker _eos_token_index < _special_tokens->size(),
378*523fa7a6SAndroid Build Coastguard Worker "invalid eos_token_index %zu",
379*523fa7a6SAndroid Build Coastguard Worker _eos_token_index);
380*523fa7a6SAndroid Build Coastguard Worker }
381*523fa7a6SAndroid Build Coastguard Worker
load(const std::string & path)382*523fa7a6SAndroid Build Coastguard Worker Error Tiktoken::load(const std::string& path) {
383*523fa7a6SAndroid Build Coastguard Worker _encoder = ET_UNWRAP(_load_encoder(path));
384*523fa7a6SAndroid Build Coastguard Worker _special_token_encoder = _build_special_token_encoder(_encoder.size());
385*523fa7a6SAndroid Build Coastguard Worker
386*523fa7a6SAndroid Build Coastguard Worker _decoder = ET_UNWRAP(_build_decoder(_encoder));
387*523fa7a6SAndroid Build Coastguard Worker _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder));
388*523fa7a6SAndroid Build Coastguard Worker
389*523fa7a6SAndroid Build Coastguard Worker _regex = _create_regex(_pattern);
390*523fa7a6SAndroid Build Coastguard Worker // Warmup re2 as it is slow on the first run, void the return value as it's
391*523fa7a6SAndroid Build Coastguard Worker // not needed Refer to
392*523fa7a6SAndroid Build Coastguard Worker // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
393*523fa7a6SAndroid Build Coastguard Worker (void)_regex->ReverseProgramSize();
394*523fa7a6SAndroid Build Coastguard Worker
395*523fa7a6SAndroid Build Coastguard Worker _special_token_regex = _build_special_token_regex(_special_token_encoder);
396*523fa7a6SAndroid Build Coastguard Worker // Same as above, warm up re2
397*523fa7a6SAndroid Build Coastguard Worker (void)_special_token_regex->ReverseProgramSize();
398*523fa7a6SAndroid Build Coastguard Worker
399*523fa7a6SAndroid Build Coastguard Worker // initialize vocab_size, bos_tok, eos_tok
400*523fa7a6SAndroid Build Coastguard Worker vocab_size_ = _encoder.size() + _special_token_encoder.size();
401*523fa7a6SAndroid Build Coastguard Worker bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
402*523fa7a6SAndroid Build Coastguard Worker eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));
403*523fa7a6SAndroid Build Coastguard Worker
404*523fa7a6SAndroid Build Coastguard Worker initialized_ = true;
405*523fa7a6SAndroid Build Coastguard Worker return Error::Ok;
406*523fa7a6SAndroid Build Coastguard Worker }
407*523fa7a6SAndroid Build Coastguard Worker
408*523fa7a6SAndroid Build Coastguard Worker Result<std::vector<uint64_t>>
encode(const std::string & text,int8_t bos,int8_t eos) const409*523fa7a6SAndroid Build Coastguard Worker Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
410*523fa7a6SAndroid Build Coastguard Worker if (!initialized_) {
411*523fa7a6SAndroid Build Coastguard Worker return Error::NotSupported;
412*523fa7a6SAndroid Build Coastguard Worker }
413*523fa7a6SAndroid Build Coastguard Worker auto res = _encode_with_special_token(text, _special_token_encoder).first;
414*523fa7a6SAndroid Build Coastguard Worker for (auto i = 0; i < bos; ++i) {
415*523fa7a6SAndroid Build Coastguard Worker res.insert(res.begin(), bos_tok_);
416*523fa7a6SAndroid Build Coastguard Worker }
417*523fa7a6SAndroid Build Coastguard Worker for (auto i = 0; i < eos; ++i) {
418*523fa7a6SAndroid Build Coastguard Worker res.push_back(eos_tok_);
419*523fa7a6SAndroid Build Coastguard Worker }
420*523fa7a6SAndroid Build Coastguard Worker return Result<std::vector<uint64_t>>(std::move(res));
421*523fa7a6SAndroid Build Coastguard Worker }
422*523fa7a6SAndroid Build Coastguard Worker
decode(uint64_t prev,uint64_t cur) const423*523fa7a6SAndroid Build Coastguard Worker Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) const {
424*523fa7a6SAndroid Build Coastguard Worker (void)prev;
425*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur));
426*523fa7a6SAndroid Build Coastguard Worker std::string ret;
427*523fa7a6SAndroid Build Coastguard Worker
428*523fa7a6SAndroid Build Coastguard Worker std::string token_bytes;
429*523fa7a6SAndroid Build Coastguard Worker auto iter = _decoder.find(cur);
430*523fa7a6SAndroid Build Coastguard Worker if (iter != _decoder.end()) {
431*523fa7a6SAndroid Build Coastguard Worker token_bytes = iter->second;
432*523fa7a6SAndroid Build Coastguard Worker } else {
433*523fa7a6SAndroid Build Coastguard Worker iter = _special_token_decoder.find(cur);
434*523fa7a6SAndroid Build Coastguard Worker if (iter != _special_token_decoder.end()) {
435*523fa7a6SAndroid Build Coastguard Worker token_bytes = iter->second;
436*523fa7a6SAndroid Build Coastguard Worker } else {
437*523fa7a6SAndroid Build Coastguard Worker ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur);
438*523fa7a6SAndroid Build Coastguard Worker }
439*523fa7a6SAndroid Build Coastguard Worker }
440*523fa7a6SAndroid Build Coastguard Worker ret += token_bytes;
441*523fa7a6SAndroid Build Coastguard Worker
442*523fa7a6SAndroid Build Coastguard Worker return ret;
443*523fa7a6SAndroid Build Coastguard Worker }
444*523fa7a6SAndroid Build Coastguard Worker // -------------------------public method end-------------------------------
445*523fa7a6SAndroid Build Coastguard Worker
446*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
447*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
448*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
449