1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker #pragma once 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/tokenizer.h> 12*523fa7a6SAndroid Build Coastguard Worker #include <re2/re2.h> 13*523fa7a6SAndroid Build Coastguard Worker #include <memory> 14*523fa7a6SAndroid Build Coastguard Worker #include <optional> 15*523fa7a6SAndroid Build Coastguard Worker #include <unordered_map> 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker namespace executorch { 18*523fa7a6SAndroid Build Coastguard Worker namespace extension { 19*523fa7a6SAndroid Build Coastguard Worker namespace llm { 20*523fa7a6SAndroid Build Coastguard Worker 21*523fa7a6SAndroid Build Coastguard Worker using Encoder = std::unordered_map<std::string, uint64_t>; 22*523fa7a6SAndroid Build Coastguard Worker using Decoder = std::unordered_map<uint64_t, std::string>; 23*523fa7a6SAndroid Build Coastguard Worker using Re2UPtr = std::unique_ptr<re2::RE2>; 24*523fa7a6SAndroid Build Coastguard Worker 25*523fa7a6SAndroid Build Coastguard Worker class ET_EXPERIMENTAL Tiktoken : public Tokenizer { 26*523fa7a6SAndroid Build Coastguard Worker public: 27*523fa7a6SAndroid Build Coastguard Worker /** 28*523fa7a6SAndroid Build Coastguard Worker * @param[in] special_tokens List of special tokens including bos, eos; 29*523fa7a6SAndroid Build Coastguard Worker * @param[in] bos_token_index Index of the bos token in special_tokens; 30*523fa7a6SAndroid Build Coastguard Worker * @param[in] eos_token_index Index of the eos token in special_tokens. 31*523fa7a6SAndroid Build Coastguard Worker */ 32*523fa7a6SAndroid Build Coastguard Worker explicit Tiktoken( 33*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<std::vector<std::string>> special_tokens, 34*523fa7a6SAndroid Build Coastguard Worker size_t bos_token_index, 35*523fa7a6SAndroid Build Coastguard Worker size_t eos_token_index); 36*523fa7a6SAndroid Build Coastguard Worker 37*523fa7a6SAndroid Build Coastguard Worker ::executorch::runtime::Error load(const std::string& tokenizer_path) override; 38*523fa7a6SAndroid Build Coastguard Worker 39*523fa7a6SAndroid Build Coastguard Worker ::executorch::runtime::Result<std::vector<uint64_t>> 40*523fa7a6SAndroid Build Coastguard Worker encode(const std::string& input, int8_t bos, int8_t eos) const override; 41*523fa7a6SAndroid Build Coastguard Worker 42*523fa7a6SAndroid Build Coastguard Worker ::executorch::runtime::Result<std::string> decode( 43*523fa7a6SAndroid Build Coastguard Worker uint64_t prev_token, 44*523fa7a6SAndroid Build Coastguard Worker uint64_t token) const override; 45*523fa7a6SAndroid Build Coastguard Worker 46*523fa7a6SAndroid Build Coastguard Worker private: 47*523fa7a6SAndroid Build Coastguard Worker template <typename T> 48*523fa7a6SAndroid Build Coastguard Worker std::pair<std::optional<std::string>, re2::StringPiece> 49*523fa7a6SAndroid Build Coastguard Worker _split_with_allowed_special_token( 50*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece& input, 51*523fa7a6SAndroid Build Coastguard Worker const T& allowed_special) const; 52*523fa7a6SAndroid Build Coastguard Worker 53*523fa7a6SAndroid Build Coastguard Worker void _encode( 54*523fa7a6SAndroid Build Coastguard Worker re2::StringPiece& input, 55*523fa7a6SAndroid Build Coastguard Worker std::vector<uint64_t>& ret, 56*523fa7a6SAndroid Build Coastguard Worker uint64_t& last_piece_token_len) const; 57*523fa7a6SAndroid Build Coastguard Worker 58*523fa7a6SAndroid Build Coastguard Worker template <typename T> 59*523fa7a6SAndroid Build Coastguard Worker std::pair<std::vector<uint64_t>, uint64_t> _encode_with_special_token( 60*523fa7a6SAndroid Build Coastguard Worker const std::string& text, 61*523fa7a6SAndroid Build Coastguard Worker const T& allowed_special) const; 62*523fa7a6SAndroid Build Coastguard Worker 63*523fa7a6SAndroid Build Coastguard Worker Encoder _build_special_token_encoder(ssize_t num_base_tokens) const; 64*523fa7a6SAndroid Build Coastguard Worker 65*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<std::vector<std::string>> _special_tokens; 66*523fa7a6SAndroid Build Coastguard Worker size_t _bos_token_index; 67*523fa7a6SAndroid Build Coastguard Worker size_t _eos_token_index; 68*523fa7a6SAndroid Build Coastguard Worker // Removed negative lookahead \s+(?!\S) since it's not supported by RE2. 69*523fa7a6SAndroid Build Coastguard Worker const std::string _pattern = 70*523fa7a6SAndroid Build Coastguard Worker R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; 71*523fa7a6SAndroid Build Coastguard Worker Encoder _encoder; 72*523fa7a6SAndroid Build Coastguard Worker Encoder _special_token_encoder; 73*523fa7a6SAndroid Build Coastguard Worker Decoder _decoder; 74*523fa7a6SAndroid Build Coastguard Worker Decoder _special_token_decoder; 75*523fa7a6SAndroid Build Coastguard Worker 76*523fa7a6SAndroid Build Coastguard Worker Re2UPtr _regex; 77*523fa7a6SAndroid Build Coastguard Worker Re2UPtr _special_token_regex; 78*523fa7a6SAndroid Build Coastguard Worker }; 79*523fa7a6SAndroid Build Coastguard Worker 80*523fa7a6SAndroid Build Coastguard Worker } // namespace llm 81*523fa7a6SAndroid Build Coastguard Worker } // namespace extension 82*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch 83*523fa7a6SAndroid Build Coastguard Worker 84*523fa7a6SAndroid Build Coastguard Worker namespace torch { 85*523fa7a6SAndroid Build Coastguard Worker namespace executor { 86*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved 87*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces. 88*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Decoder; 89*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Encoder; 90*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Re2UPtr; 91*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::llm::Tiktoken; 92*523fa7a6SAndroid Build Coastguard Worker } // namespace executor 93*523fa7a6SAndroid Build Coastguard Worker } // namespace torch 94