xref: /aosp_15_r20/external/executorch/extension/llm/tokenizer/tiktoken.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker // Adopted from https://github.com/sewenew/tokenizer
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker // @lint-ignore-every LICENSELINT
12*523fa7a6SAndroid Build Coastguard Worker /**************************************************************************
13*523fa7a6SAndroid Build Coastguard Worker    Copyright (c) 2023 sewenew
14*523fa7a6SAndroid Build Coastguard Worker 
15*523fa7a6SAndroid Build Coastguard Worker    Licensed under the Apache License, Version 2.0 (the "License");
16*523fa7a6SAndroid Build Coastguard Worker    you may not use this file except in compliance with the License.
17*523fa7a6SAndroid Build Coastguard Worker    You may obtain a copy of the License at
18*523fa7a6SAndroid Build Coastguard Worker 
19*523fa7a6SAndroid Build Coastguard Worker        http://www.apache.org/licenses/LICENSE-2.0
20*523fa7a6SAndroid Build Coastguard Worker 
21*523fa7a6SAndroid Build Coastguard Worker    Unless required by applicable law or agreed to in writing, software
22*523fa7a6SAndroid Build Coastguard Worker    distributed under the License is distributed on an "AS IS" BASIS,
23*523fa7a6SAndroid Build Coastguard Worker    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24*523fa7a6SAndroid Build Coastguard Worker    See the License for the specific language governing permissions and
25*523fa7a6SAndroid Build Coastguard Worker    limitations under the License.
26*523fa7a6SAndroid Build Coastguard Worker  *************************************************************************/
27*523fa7a6SAndroid Build Coastguard Worker 
28*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/base64.h>
29*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/tiktoken.h>
30*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/result.h>
31*523fa7a6SAndroid Build Coastguard Worker #include <fstream>
32*523fa7a6SAndroid Build Coastguard Worker #include <limits>
33*523fa7a6SAndroid Build Coastguard Worker 
34*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Error;
35*523fa7a6SAndroid Build Coastguard Worker using ::executorch::runtime::Result;
36*523fa7a6SAndroid Build Coastguard Worker 
37*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
38*523fa7a6SAndroid Build Coastguard Worker namespace extension {
39*523fa7a6SAndroid Build Coastguard Worker namespace llm {
40*523fa7a6SAndroid Build Coastguard Worker 
41*523fa7a6SAndroid Build Coastguard Worker // ------------------------------Util start------------------------------------
42*523fa7a6SAndroid Build Coastguard Worker 
// Largest representable uint64_t. The BPE merge code below uses it as a
// sentinel meaning "no valid rank".
static uint64_t _max_size() {
  constexpr uint64_t kLargest = std::numeric_limits<uint64_t>::max();
  return kLargest;
}
46*523fa7a6SAndroid Build Coastguard Worker 
_create_regex(const std::string & pattern)47*523fa7a6SAndroid Build Coastguard Worker static Re2UPtr _create_regex(const std::string& pattern) {
48*523fa7a6SAndroid Build Coastguard Worker   assert(!pattern.empty());
49*523fa7a6SAndroid Build Coastguard Worker 
50*523fa7a6SAndroid Build Coastguard Worker   return std::make_unique<re2::RE2>("(" + pattern + ")");
51*523fa7a6SAndroid Build Coastguard Worker }
52*523fa7a6SAndroid Build Coastguard Worker 
_build_special_token_regex(const Encoder & special_encoder)53*523fa7a6SAndroid Build Coastguard Worker static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
54*523fa7a6SAndroid Build Coastguard Worker   std::string special_pattern;
55*523fa7a6SAndroid Build Coastguard Worker   for (const auto& ele : special_encoder) {
56*523fa7a6SAndroid Build Coastguard Worker     if (!special_pattern.empty()) {
57*523fa7a6SAndroid Build Coastguard Worker       special_pattern += "|";
58*523fa7a6SAndroid Build Coastguard Worker     }
59*523fa7a6SAndroid Build Coastguard Worker     special_pattern += re2::RE2::QuoteMeta(ele.first);
60*523fa7a6SAndroid Build Coastguard Worker   }
61*523fa7a6SAndroid Build Coastguard Worker 
62*523fa7a6SAndroid Build Coastguard Worker   if (special_pattern.empty()) {
63*523fa7a6SAndroid Build Coastguard Worker     return nullptr;
64*523fa7a6SAndroid Build Coastguard Worker   }
65*523fa7a6SAndroid Build Coastguard Worker 
66*523fa7a6SAndroid Build Coastguard Worker   return _create_regex(special_pattern);
67*523fa7a6SAndroid Build Coastguard Worker }
68*523fa7a6SAndroid Build Coastguard Worker 
_parse(const std::string & line)69*523fa7a6SAndroid Build Coastguard Worker static Result<std::pair<std::string, uint64_t>> _parse(
70*523fa7a6SAndroid Build Coastguard Worker     const std::string& line) {
71*523fa7a6SAndroid Build Coastguard Worker   // Tiktoken format
72*523fa7a6SAndroid Build Coastguard Worker   // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 <base64
73*523fa7a6SAndroid Build Coastguard Worker   // encoded token str> <rank>
74*523fa7a6SAndroid Build Coastguard Worker   auto pos = line.find(" ");
75*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
76*523fa7a6SAndroid Build Coastguard Worker       pos != std::string::npos,
77*523fa7a6SAndroid Build Coastguard Worker       InvalidArgument,
78*523fa7a6SAndroid Build Coastguard Worker       "invalid tiktoken line: %s",
79*523fa7a6SAndroid Build Coastguard Worker       line.c_str());
80*523fa7a6SAndroid Build Coastguard Worker 
81*523fa7a6SAndroid Build Coastguard Worker   auto token = ET_UNWRAP(base64::decode({line.data(), pos}));
82*523fa7a6SAndroid Build Coastguard Worker   uint64_t rank = 0;
83*523fa7a6SAndroid Build Coastguard Worker   try {
84*523fa7a6SAndroid Build Coastguard Worker     rank = std::stoul(line.substr(pos + 1));
85*523fa7a6SAndroid Build Coastguard Worker   } catch (const std::exception&) {
86*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_OR_RETURN_ERROR(
87*523fa7a6SAndroid Build Coastguard Worker         false, InvalidArgument, "invalid encoder rank: %s", line.c_str());
88*523fa7a6SAndroid Build Coastguard Worker   }
89*523fa7a6SAndroid Build Coastguard Worker 
90*523fa7a6SAndroid Build Coastguard Worker   return std::pair{std::move(token), rank};
91*523fa7a6SAndroid Build Coastguard Worker }
92*523fa7a6SAndroid Build Coastguard Worker 
_load_encoder(const std::string & path)93*523fa7a6SAndroid Build Coastguard Worker static Result<Encoder> _load_encoder(const std::string& path) {
94*523fa7a6SAndroid Build Coastguard Worker   std::ifstream file(path);
95*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
96*523fa7a6SAndroid Build Coastguard Worker       file, InvalidArgument, "failed to open encoder file: %s", path.c_str());
97*523fa7a6SAndroid Build Coastguard Worker 
98*523fa7a6SAndroid Build Coastguard Worker   Encoder encoder;
99*523fa7a6SAndroid Build Coastguard Worker   std::string line;
100*523fa7a6SAndroid Build Coastguard Worker   while (std::getline(file, line)) {
101*523fa7a6SAndroid Build Coastguard Worker     auto [token, rank] = ET_UNWRAP(_parse(line));
102*523fa7a6SAndroid Build Coastguard Worker 
103*523fa7a6SAndroid Build Coastguard Worker     ET_CHECK_OR_RETURN_ERROR(
104*523fa7a6SAndroid Build Coastguard Worker         encoder.emplace(std::move(token), rank).second,
105*523fa7a6SAndroid Build Coastguard Worker         InvalidArgument,
106*523fa7a6SAndroid Build Coastguard Worker         "duplicate item: %s",
107*523fa7a6SAndroid Build Coastguard Worker         line.c_str());
108*523fa7a6SAndroid Build Coastguard Worker   }
109*523fa7a6SAndroid Build Coastguard Worker 
110*523fa7a6SAndroid Build Coastguard Worker   return encoder;
111*523fa7a6SAndroid Build Coastguard Worker }
112*523fa7a6SAndroid Build Coastguard Worker 
_build_decoder(const Encoder & encoder)113*523fa7a6SAndroid Build Coastguard Worker static Result<Decoder> _build_decoder(const Encoder& encoder) {
114*523fa7a6SAndroid Build Coastguard Worker   Decoder decoder;
115*523fa7a6SAndroid Build Coastguard Worker   for (const auto& [k, v] : encoder) {
116*523fa7a6SAndroid Build Coastguard Worker     decoder.emplace(v, k);
117*523fa7a6SAndroid Build Coastguard Worker   }
118*523fa7a6SAndroid Build Coastguard Worker 
119*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OR_RETURN_ERROR(
120*523fa7a6SAndroid Build Coastguard Worker       encoder.size() == decoder.size(),
121*523fa7a6SAndroid Build Coastguard Worker       InvalidArgument,
122*523fa7a6SAndroid Build Coastguard Worker       "duplicate items in encoder");
123*523fa7a6SAndroid Build Coastguard Worker 
124*523fa7a6SAndroid Build Coastguard Worker   return decoder;
125*523fa7a6SAndroid Build Coastguard Worker }
126*523fa7a6SAndroid Build Coastguard Worker 
_byte_pair_merge(const std::string & piece,const std::unordered_map<std::string,uint64_t> & ranks,std::function<uint64_t (uint64_t,uint64_t)> func)127*523fa7a6SAndroid Build Coastguard Worker static std::vector<uint64_t> _byte_pair_merge(
128*523fa7a6SAndroid Build Coastguard Worker     const std::string& piece,
129*523fa7a6SAndroid Build Coastguard Worker     const std::unordered_map<std::string, uint64_t>& ranks,
130*523fa7a6SAndroid Build Coastguard Worker     std::function<uint64_t(uint64_t, uint64_t)> func) {
131*523fa7a6SAndroid Build Coastguard Worker   // This is a vector of (start, rank).
132*523fa7a6SAndroid Build Coastguard Worker   // The rank is of the byte pair starting at position start.
133*523fa7a6SAndroid Build Coastguard Worker   // The rank of the last item in the vector is not a valid value.
134*523fa7a6SAndroid Build Coastguard Worker   std::vector<std::pair<uint64_t, uint64_t>> parts;
135*523fa7a6SAndroid Build Coastguard Worker   parts.reserve(piece.size() + 1);
136*523fa7a6SAndroid Build Coastguard Worker   for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
137*523fa7a6SAndroid Build Coastguard Worker     parts.emplace_back(idx, _max_size());
138*523fa7a6SAndroid Build Coastguard Worker   }
139*523fa7a6SAndroid Build Coastguard Worker 
140*523fa7a6SAndroid Build Coastguard Worker   auto get_rank = [&piece, &ranks](
141*523fa7a6SAndroid Build Coastguard Worker                       const std::vector<std::pair<uint64_t, uint64_t>>& parts,
142*523fa7a6SAndroid Build Coastguard Worker                       uint64_t start_idx,
143*523fa7a6SAndroid Build Coastguard Worker                       uint64_t skip) -> std::optional<uint64_t> {
144*523fa7a6SAndroid Build Coastguard Worker     if (start_idx + skip + 2 < parts.size()) {
145*523fa7a6SAndroid Build Coastguard Worker       auto s = parts[start_idx].first;
146*523fa7a6SAndroid Build Coastguard Worker       auto e = parts[start_idx + skip + 2].first;
147*523fa7a6SAndroid Build Coastguard Worker       auto key = piece.substr(s, e - s);
148*523fa7a6SAndroid Build Coastguard Worker       auto iter = ranks.find(key);
149*523fa7a6SAndroid Build Coastguard Worker       if (iter != ranks.end()) {
150*523fa7a6SAndroid Build Coastguard Worker         return iter->second;
151*523fa7a6SAndroid Build Coastguard Worker       }
152*523fa7a6SAndroid Build Coastguard Worker     }
153*523fa7a6SAndroid Build Coastguard Worker     return std::nullopt;
154*523fa7a6SAndroid Build Coastguard Worker   };
155*523fa7a6SAndroid Build Coastguard Worker 
156*523fa7a6SAndroid Build Coastguard Worker   // We look up the ranks once in the beginning and iteratively update
157*523fa7a6SAndroid Build Coastguard Worker   // them during each merge, which reduces the number of rank lookups.
158*523fa7a6SAndroid Build Coastguard Worker   for (auto i = 0U; i < parts.size() - 2; ++i) {
159*523fa7a6SAndroid Build Coastguard Worker     auto rank = get_rank(parts, i, 0);
160*523fa7a6SAndroid Build Coastguard Worker     if (rank) {
161*523fa7a6SAndroid Build Coastguard Worker       // usize::MAX is a sentinel value and cannot be a valid rank
162*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_MSG(*rank != _max_size(), "rank is too large");
163*523fa7a6SAndroid Build Coastguard Worker       parts[i].second = *rank;
164*523fa7a6SAndroid Build Coastguard Worker     }
165*523fa7a6SAndroid Build Coastguard Worker   }
166*523fa7a6SAndroid Build Coastguard Worker 
167*523fa7a6SAndroid Build Coastguard Worker   // If you have n parts and m merges, this does O(mn) work.
168*523fa7a6SAndroid Build Coastguard Worker   // We could do something with a heap and do O(m log n) work.
169*523fa7a6SAndroid Build Coastguard Worker   // It is important to consider that n is often small (<100), and as such
170*523fa7a6SAndroid Build Coastguard Worker   // the cache-locality benefits outweigh the algorithmic complexity downsides
171*523fa7a6SAndroid Build Coastguard Worker   // of the `parts` vector data structure above.
172*523fa7a6SAndroid Build Coastguard Worker 
173*523fa7a6SAndroid Build Coastguard Worker   // Note that we hash bytes, not token pairs. As long as we train BPE the way
174*523fa7a6SAndroid Build Coastguard Worker   // we currently do, this is equivalent. An easy way to break this would be
175*523fa7a6SAndroid Build Coastguard Worker   // to decouple merge priority from token index or to prevent specific token
176*523fa7a6SAndroid Build Coastguard Worker   // merges.
177*523fa7a6SAndroid Build Coastguard Worker   while (true) {
178*523fa7a6SAndroid Build Coastguard Worker     if (parts.size() == 1) {
179*523fa7a6SAndroid Build Coastguard Worker       break;
180*523fa7a6SAndroid Build Coastguard Worker     }
181*523fa7a6SAndroid Build Coastguard Worker 
182*523fa7a6SAndroid Build Coastguard Worker     // usize::MAX is a sentinel rank value allowing us to
183*523fa7a6SAndroid Build Coastguard Worker     // take the min more quickly
184*523fa7a6SAndroid Build Coastguard Worker     auto min_rank = std::make_pair<uint64_t, uint64_t>(_max_size(), 0);
185*523fa7a6SAndroid Build Coastguard Worker     for (auto i = 0U; i < parts.size() - 1; ++i) {
186*523fa7a6SAndroid Build Coastguard Worker       auto rank = parts[i].second;
187*523fa7a6SAndroid Build Coastguard Worker       if (rank < min_rank.first) {
188*523fa7a6SAndroid Build Coastguard Worker         min_rank.first = rank;
189*523fa7a6SAndroid Build Coastguard Worker         min_rank.second = i;
190*523fa7a6SAndroid Build Coastguard Worker       }
191*523fa7a6SAndroid Build Coastguard Worker     }
192*523fa7a6SAndroid Build Coastguard Worker 
193*523fa7a6SAndroid Build Coastguard Worker     if (min_rank.first != _max_size()) {
194*523fa7a6SAndroid Build Coastguard Worker       auto i = min_rank.second;
195*523fa7a6SAndroid Build Coastguard Worker 
196*523fa7a6SAndroid Build Coastguard Worker       // NOTE: We are about to remove parts[i + 1]. We do not do it
197*523fa7a6SAndroid Build Coastguard Worker       // yet because there are cache-locality benefits to updating
198*523fa7a6SAndroid Build Coastguard Worker       // parts[i] and parts[i-1] before removing, which could thrash
199*523fa7a6SAndroid Build Coastguard Worker       // the cache. Thus, we update the rank calculation by skipping over
200*523fa7a6SAndroid Build Coastguard Worker       // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
201*523fa7a6SAndroid Build Coastguard Worker       auto rank = get_rank(parts, i, 1);
202*523fa7a6SAndroid Build Coastguard Worker       if (rank) {
203*523fa7a6SAndroid Build Coastguard Worker         parts[i].second = *rank;
204*523fa7a6SAndroid Build Coastguard Worker       } else {
205*523fa7a6SAndroid Build Coastguard Worker         parts[i].second = _max_size();
206*523fa7a6SAndroid Build Coastguard Worker       }
207*523fa7a6SAndroid Build Coastguard Worker       if (i > 0) {
208*523fa7a6SAndroid Build Coastguard Worker         rank = get_rank(parts, i - 1, 1);
209*523fa7a6SAndroid Build Coastguard Worker         if (rank) {
210*523fa7a6SAndroid Build Coastguard Worker           parts[i - 1].second = *rank;
211*523fa7a6SAndroid Build Coastguard Worker         } else {
212*523fa7a6SAndroid Build Coastguard Worker           parts[i - 1].second = _max_size();
213*523fa7a6SAndroid Build Coastguard Worker         }
214*523fa7a6SAndroid Build Coastguard Worker       }
215*523fa7a6SAndroid Build Coastguard Worker 
216*523fa7a6SAndroid Build Coastguard Worker       parts.erase(parts.begin() + (i + 1));
217*523fa7a6SAndroid Build Coastguard Worker     } else {
218*523fa7a6SAndroid Build Coastguard Worker       break;
219*523fa7a6SAndroid Build Coastguard Worker     }
220*523fa7a6SAndroid Build Coastguard Worker   }
221*523fa7a6SAndroid Build Coastguard Worker   std::vector<uint64_t> out;
222*523fa7a6SAndroid Build Coastguard Worker   out.reserve(parts.size() - 1);
223*523fa7a6SAndroid Build Coastguard Worker   for (auto i = 0U; i < parts.size() - 1; ++i) {
224*523fa7a6SAndroid Build Coastguard Worker     auto s = parts[i].first;
225*523fa7a6SAndroid Build Coastguard Worker     auto e = parts[i + 1].first;
226*523fa7a6SAndroid Build Coastguard Worker     out.push_back(func(s, e));
227*523fa7a6SAndroid Build Coastguard Worker   }
228*523fa7a6SAndroid Build Coastguard Worker   return out;
229*523fa7a6SAndroid Build Coastguard Worker }
230*523fa7a6SAndroid Build Coastguard Worker 
_byte_pair_encode(const std::string & piece,const Encoder & encoder)231*523fa7a6SAndroid Build Coastguard Worker static std::vector<uint64_t> _byte_pair_encode(
232*523fa7a6SAndroid Build Coastguard Worker     const std::string& piece,
233*523fa7a6SAndroid Build Coastguard Worker     const Encoder& encoder) {
234*523fa7a6SAndroid Build Coastguard Worker   if (piece.size() == 1) {
235*523fa7a6SAndroid Build Coastguard Worker     auto iter = encoder.find(piece);
236*523fa7a6SAndroid Build Coastguard Worker     if (iter != encoder.end()) {
237*523fa7a6SAndroid Build Coastguard Worker       return std::vector<uint64_t>({iter->second});
238*523fa7a6SAndroid Build Coastguard Worker     } else {
239*523fa7a6SAndroid Build Coastguard Worker       // TODO: is it possible?
240*523fa7a6SAndroid Build Coastguard Worker       return {};
241*523fa7a6SAndroid Build Coastguard Worker     }
242*523fa7a6SAndroid Build Coastguard Worker   }
243*523fa7a6SAndroid Build Coastguard Worker 
244*523fa7a6SAndroid Build Coastguard Worker   return _byte_pair_merge(
245*523fa7a6SAndroid Build Coastguard Worker       piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) {
246*523fa7a6SAndroid Build Coastguard Worker         std::string key = piece.substr(start, stop - start);
247*523fa7a6SAndroid Build Coastguard Worker         auto iter = encoder.find(key);
248*523fa7a6SAndroid Build Coastguard Worker         if (iter != encoder.end()) {
249*523fa7a6SAndroid Build Coastguard Worker           return iter->second;
250*523fa7a6SAndroid Build Coastguard Worker         } else {
251*523fa7a6SAndroid Build Coastguard Worker           // TODO: what if key does not exist? Should we return `unknown`?
252*523fa7a6SAndroid Build Coastguard Worker           // assert(false); // ??
253*523fa7a6SAndroid Build Coastguard Worker           return uint64_t(0);
254*523fa7a6SAndroid Build Coastguard Worker         }
255*523fa7a6SAndroid Build Coastguard Worker       });
256*523fa7a6SAndroid Build Coastguard Worker }
257*523fa7a6SAndroid Build Coastguard Worker // ------------------------------Util end------------------------------------
258*523fa7a6SAndroid Build Coastguard Worker // -------------------------private method start-------------------------------
259*523fa7a6SAndroid Build Coastguard Worker 
260*523fa7a6SAndroid Build Coastguard Worker template <typename T>
261*523fa7a6SAndroid Build Coastguard Worker std::pair<std::optional<std::string>, re2::StringPiece>
_split_with_allowed_special_token(re2::StringPiece & input,const T & allowed_special) const262*523fa7a6SAndroid Build Coastguard Worker Tiktoken::_split_with_allowed_special_token(
263*523fa7a6SAndroid Build Coastguard Worker     re2::StringPiece& input,
264*523fa7a6SAndroid Build Coastguard Worker     const T& allowed_special) const {
265*523fa7a6SAndroid Build Coastguard Worker   if (!_special_token_regex) {
266*523fa7a6SAndroid Build Coastguard Worker     return std::make_pair(std::nullopt, input);
267*523fa7a6SAndroid Build Coastguard Worker   }
268*523fa7a6SAndroid Build Coastguard Worker 
269*523fa7a6SAndroid Build Coastguard Worker #if __cplusplus >= 202002L
270*523fa7a6SAndroid Build Coastguard Worker   auto start = input.begin();
271*523fa7a6SAndroid Build Coastguard Worker #else
272*523fa7a6SAndroid Build Coastguard Worker   const char* start = input.data();
273*523fa7a6SAndroid Build Coastguard Worker #endif
274*523fa7a6SAndroid Build Coastguard Worker   std::string special;
275*523fa7a6SAndroid Build Coastguard Worker   while (true) {
276*523fa7a6SAndroid Build Coastguard Worker     if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
277*523fa7a6SAndroid Build Coastguard Worker       // No special token.
278*523fa7a6SAndroid Build Coastguard Worker       break;
279*523fa7a6SAndroid Build Coastguard Worker     }
280*523fa7a6SAndroid Build Coastguard Worker 
281*523fa7a6SAndroid Build Coastguard Worker     if (allowed_special.count(special) == 1) {
282*523fa7a6SAndroid Build Coastguard Worker       // Found an allowed special token, split the text with it.
283*523fa7a6SAndroid Build Coastguard Worker #if __cplusplus >= 202002L
284*523fa7a6SAndroid Build Coastguard Worker       return std::make_pair(
285*523fa7a6SAndroid Build Coastguard Worker           special,
286*523fa7a6SAndroid Build Coastguard Worker           re2::StringPiece(start, input.begin() - start - special.size()));
287*523fa7a6SAndroid Build Coastguard Worker #else
288*523fa7a6SAndroid Build Coastguard Worker       return std::make_pair(
289*523fa7a6SAndroid Build Coastguard Worker           special,
290*523fa7a6SAndroid Build Coastguard Worker           re2::StringPiece(start, (input.data() - start) - special.size()));
291*523fa7a6SAndroid Build Coastguard Worker #endif
292*523fa7a6SAndroid Build Coastguard Worker     } // else try to find the next special token
293*523fa7a6SAndroid Build Coastguard Worker   }
294*523fa7a6SAndroid Build Coastguard Worker 
295*523fa7a6SAndroid Build Coastguard Worker   return std::make_pair(std::nullopt, input);
296*523fa7a6SAndroid Build Coastguard Worker }
297*523fa7a6SAndroid Build Coastguard Worker 
_encode(re2::StringPiece & input,std::vector<uint64_t> & ret,uint64_t & last_piece_token_len) const298*523fa7a6SAndroid Build Coastguard Worker void Tiktoken::_encode(
299*523fa7a6SAndroid Build Coastguard Worker     re2::StringPiece& input,
300*523fa7a6SAndroid Build Coastguard Worker     std::vector<uint64_t>& ret,
301*523fa7a6SAndroid Build Coastguard Worker     uint64_t& last_piece_token_len) const {
302*523fa7a6SAndroid Build Coastguard Worker   std::string piece;
303*523fa7a6SAndroid Build Coastguard Worker   assert(_regex);
304*523fa7a6SAndroid Build Coastguard Worker   while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
305*523fa7a6SAndroid Build Coastguard Worker     auto iter = _encoder.find(piece);
306*523fa7a6SAndroid Build Coastguard Worker     if (iter != _encoder.end()) {
307*523fa7a6SAndroid Build Coastguard Worker       last_piece_token_len = 1;
308*523fa7a6SAndroid Build Coastguard Worker       ret.push_back(iter->second);
309*523fa7a6SAndroid Build Coastguard Worker       continue;
310*523fa7a6SAndroid Build Coastguard Worker     }
311*523fa7a6SAndroid Build Coastguard Worker     auto tokens = _byte_pair_encode(piece, _encoder);
312*523fa7a6SAndroid Build Coastguard Worker     last_piece_token_len = tokens.size();
313*523fa7a6SAndroid Build Coastguard Worker     ret.insert(ret.end(), tokens.begin(), tokens.end());
314*523fa7a6SAndroid Build Coastguard Worker   }
315*523fa7a6SAndroid Build Coastguard Worker }
316*523fa7a6SAndroid Build Coastguard Worker 
317*523fa7a6SAndroid Build Coastguard Worker template <typename T>
_encode_with_special_token(const std::string & text,const T & allowed_special) const318*523fa7a6SAndroid Build Coastguard Worker std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
319*523fa7a6SAndroid Build Coastguard Worker     const std::string& text,
320*523fa7a6SAndroid Build Coastguard Worker     const T& allowed_special) const {
321*523fa7a6SAndroid Build Coastguard Worker   std::vector<uint64_t> tokens;
322*523fa7a6SAndroid Build Coastguard Worker   uint64_t last_piece_token_len = 0;
323*523fa7a6SAndroid Build Coastguard Worker   re2::StringPiece input(text);
324*523fa7a6SAndroid Build Coastguard Worker   while (true) {
325*523fa7a6SAndroid Build Coastguard Worker     auto [special, sub_input] =
326*523fa7a6SAndroid Build Coastguard Worker         _split_with_allowed_special_token(input, allowed_special);
327*523fa7a6SAndroid Build Coastguard Worker 
328*523fa7a6SAndroid Build Coastguard Worker     _encode(sub_input, tokens, last_piece_token_len);
329*523fa7a6SAndroid Build Coastguard Worker 
330*523fa7a6SAndroid Build Coastguard Worker     if (special) {
331*523fa7a6SAndroid Build Coastguard Worker       uint64_t token = 0;
332*523fa7a6SAndroid Build Coastguard Worker       try {
333*523fa7a6SAndroid Build Coastguard Worker         token = _special_token_encoder.at(*special);
334*523fa7a6SAndroid Build Coastguard Worker       } catch (const std::out_of_range&) {
335*523fa7a6SAndroid Build Coastguard Worker         // Should never go here, since special pattern includes all special
336*523fa7a6SAndroid Build Coastguard Worker         // chars.
337*523fa7a6SAndroid Build Coastguard Worker         ET_CHECK_MSG(false, "unknown special token: %s", special->c_str());
338*523fa7a6SAndroid Build Coastguard Worker       }
339*523fa7a6SAndroid Build Coastguard Worker 
340*523fa7a6SAndroid Build Coastguard Worker       tokens.push_back(token);
341*523fa7a6SAndroid Build Coastguard Worker       last_piece_token_len = 0;
342*523fa7a6SAndroid Build Coastguard Worker     } else {
343*523fa7a6SAndroid Build Coastguard Worker       break;
344*523fa7a6SAndroid Build Coastguard Worker     }
345*523fa7a6SAndroid Build Coastguard Worker   }
346*523fa7a6SAndroid Build Coastguard Worker 
347*523fa7a6SAndroid Build Coastguard Worker   // last_piece_token_len is how many tokens came from the last regex split.
348*523fa7a6SAndroid Build Coastguard Worker   // This is used for determining unstable tokens, since you can't merge
349*523fa7a6SAndroid Build Coastguard Worker   // across (stable) regex splits
350*523fa7a6SAndroid Build Coastguard Worker   return std::make_pair(tokens, last_piece_token_len);
351*523fa7a6SAndroid Build Coastguard Worker }
352*523fa7a6SAndroid Build Coastguard Worker 
_build_special_token_encoder(ssize_t num_base_tokens) const353*523fa7a6SAndroid Build Coastguard Worker Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
354*523fa7a6SAndroid Build Coastguard Worker   Encoder special_token_encoder;
355*523fa7a6SAndroid Build Coastguard Worker   for (ssize_t i = 0; i < _special_tokens->size(); ++i) {
356*523fa7a6SAndroid Build Coastguard Worker     special_token_encoder.emplace(_special_tokens->at(i), num_base_tokens + i);
357*523fa7a6SAndroid Build Coastguard Worker   }
358*523fa7a6SAndroid Build Coastguard Worker   return special_token_encoder;
359*523fa7a6SAndroid Build Coastguard Worker }
360*523fa7a6SAndroid Build Coastguard Worker 
361*523fa7a6SAndroid Build Coastguard Worker // -------------------------private method end-------------------------------
362*523fa7a6SAndroid Build Coastguard Worker // -------------------------public method start-------------------------------
363*523fa7a6SAndroid Build Coastguard Worker 
Tiktoken(std::unique_ptr<std::vector<std::string>> special_tokens,size_t bos_token_index,size_t eos_token_index)364*523fa7a6SAndroid Build Coastguard Worker Tiktoken::Tiktoken(
365*523fa7a6SAndroid Build Coastguard Worker     std::unique_ptr<std::vector<std::string>> special_tokens,
366*523fa7a6SAndroid Build Coastguard Worker     size_t bos_token_index,
367*523fa7a6SAndroid Build Coastguard Worker     size_t eos_token_index)
368*523fa7a6SAndroid Build Coastguard Worker     : Tokenizer(),
369*523fa7a6SAndroid Build Coastguard Worker       _special_tokens(std::move(special_tokens)),
370*523fa7a6SAndroid Build Coastguard Worker       _bos_token_index(bos_token_index),
371*523fa7a6SAndroid Build Coastguard Worker       _eos_token_index(eos_token_index) {
372*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
373*523fa7a6SAndroid Build Coastguard Worker       _bos_token_index < _special_tokens->size(),
374*523fa7a6SAndroid Build Coastguard Worker       "invalid bos_token_index %zu",
375*523fa7a6SAndroid Build Coastguard Worker       _bos_token_index);
376*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_MSG(
377*523fa7a6SAndroid Build Coastguard Worker       _eos_token_index < _special_tokens->size(),
378*523fa7a6SAndroid Build Coastguard Worker       "invalid eos_token_index %zu",
379*523fa7a6SAndroid Build Coastguard Worker       _eos_token_index);
380*523fa7a6SAndroid Build Coastguard Worker }
381*523fa7a6SAndroid Build Coastguard Worker 
load(const std::string & path)382*523fa7a6SAndroid Build Coastguard Worker Error Tiktoken::load(const std::string& path) {
383*523fa7a6SAndroid Build Coastguard Worker   _encoder = ET_UNWRAP(_load_encoder(path));
384*523fa7a6SAndroid Build Coastguard Worker   _special_token_encoder = _build_special_token_encoder(_encoder.size());
385*523fa7a6SAndroid Build Coastguard Worker 
386*523fa7a6SAndroid Build Coastguard Worker   _decoder = ET_UNWRAP(_build_decoder(_encoder));
387*523fa7a6SAndroid Build Coastguard Worker   _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder));
388*523fa7a6SAndroid Build Coastguard Worker 
389*523fa7a6SAndroid Build Coastguard Worker   _regex = _create_regex(_pattern);
390*523fa7a6SAndroid Build Coastguard Worker   // Warmup re2 as it is slow on the first run, void the return value as it's
391*523fa7a6SAndroid Build Coastguard Worker   // not needed Refer to
392*523fa7a6SAndroid Build Coastguard Worker   // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
393*523fa7a6SAndroid Build Coastguard Worker   (void)_regex->ReverseProgramSize();
394*523fa7a6SAndroid Build Coastguard Worker 
395*523fa7a6SAndroid Build Coastguard Worker   _special_token_regex = _build_special_token_regex(_special_token_encoder);
396*523fa7a6SAndroid Build Coastguard Worker   // Same as above, warm up re2
397*523fa7a6SAndroid Build Coastguard Worker   (void)_special_token_regex->ReverseProgramSize();
398*523fa7a6SAndroid Build Coastguard Worker 
399*523fa7a6SAndroid Build Coastguard Worker   // initialize vocab_size, bos_tok, eos_tok
400*523fa7a6SAndroid Build Coastguard Worker   vocab_size_ = _encoder.size() + _special_token_encoder.size();
401*523fa7a6SAndroid Build Coastguard Worker   bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
402*523fa7a6SAndroid Build Coastguard Worker   eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));
403*523fa7a6SAndroid Build Coastguard Worker 
404*523fa7a6SAndroid Build Coastguard Worker   initialized_ = true;
405*523fa7a6SAndroid Build Coastguard Worker   return Error::Ok;
406*523fa7a6SAndroid Build Coastguard Worker }
407*523fa7a6SAndroid Build Coastguard Worker 
408*523fa7a6SAndroid Build Coastguard Worker Result<std::vector<uint64_t>>
encode(const std::string & text,int8_t bos,int8_t eos) const409*523fa7a6SAndroid Build Coastguard Worker Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
410*523fa7a6SAndroid Build Coastguard Worker   if (!initialized_) {
411*523fa7a6SAndroid Build Coastguard Worker     return Error::NotSupported;
412*523fa7a6SAndroid Build Coastguard Worker   }
413*523fa7a6SAndroid Build Coastguard Worker   auto res = _encode_with_special_token(text, _special_token_encoder).first;
414*523fa7a6SAndroid Build Coastguard Worker   for (auto i = 0; i < bos; ++i) {
415*523fa7a6SAndroid Build Coastguard Worker     res.insert(res.begin(), bos_tok_);
416*523fa7a6SAndroid Build Coastguard Worker   }
417*523fa7a6SAndroid Build Coastguard Worker   for (auto i = 0; i < eos; ++i) {
418*523fa7a6SAndroid Build Coastguard Worker     res.push_back(eos_tok_);
419*523fa7a6SAndroid Build Coastguard Worker   }
420*523fa7a6SAndroid Build Coastguard Worker   return Result<std::vector<uint64_t>>(std::move(res));
421*523fa7a6SAndroid Build Coastguard Worker }
422*523fa7a6SAndroid Build Coastguard Worker 
decode(uint64_t prev,uint64_t cur) const423*523fa7a6SAndroid Build Coastguard Worker Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) const {
424*523fa7a6SAndroid Build Coastguard Worker   (void)prev;
425*523fa7a6SAndroid Build Coastguard Worker   ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur));
426*523fa7a6SAndroid Build Coastguard Worker   std::string ret;
427*523fa7a6SAndroid Build Coastguard Worker 
428*523fa7a6SAndroid Build Coastguard Worker   std::string token_bytes;
429*523fa7a6SAndroid Build Coastguard Worker   auto iter = _decoder.find(cur);
430*523fa7a6SAndroid Build Coastguard Worker   if (iter != _decoder.end()) {
431*523fa7a6SAndroid Build Coastguard Worker     token_bytes = iter->second;
432*523fa7a6SAndroid Build Coastguard Worker   } else {
433*523fa7a6SAndroid Build Coastguard Worker     iter = _special_token_decoder.find(cur);
434*523fa7a6SAndroid Build Coastguard Worker     if (iter != _special_token_decoder.end()) {
435*523fa7a6SAndroid Build Coastguard Worker       token_bytes = iter->second;
436*523fa7a6SAndroid Build Coastguard Worker     } else {
437*523fa7a6SAndroid Build Coastguard Worker       ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur);
438*523fa7a6SAndroid Build Coastguard Worker     }
439*523fa7a6SAndroid Build Coastguard Worker   }
440*523fa7a6SAndroid Build Coastguard Worker   ret += token_bytes;
441*523fa7a6SAndroid Build Coastguard Worker 
442*523fa7a6SAndroid Build Coastguard Worker   return ret;
443*523fa7a6SAndroid Build Coastguard Worker }
444*523fa7a6SAndroid Build Coastguard Worker // -------------------------public method end-------------------------------
445*523fa7a6SAndroid Build Coastguard Worker 
446*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
447*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
448*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
449