1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/extension/llm/tokenizer/tokenizer.h> 12 #include <re2/re2.h> 13 #include <memory> 14 #include <optional> 15 #include <unordered_map> 16 17 namespace executorch { 18 namespace extension { 19 namespace llm { 20 21 using Encoder = std::unordered_map<std::string, uint64_t>; 22 using Decoder = std::unordered_map<uint64_t, std::string>; 23 using Re2UPtr = std::unique_ptr<re2::RE2>; 24 25 class ET_EXPERIMENTAL Tiktoken : public Tokenizer { 26 public: 27 /** 28 * @param[in] special_tokens List of special tokens including bos, eos; 29 * @param[in] bos_token_index Index of the bos token in special_tokens; 30 * @param[in] eos_token_index Index of the eos token in special_tokens. 31 */ 32 explicit Tiktoken( 33 std::unique_ptr<std::vector<std::string>> special_tokens, 34 size_t bos_token_index, 35 size_t eos_token_index); 36 37 ::executorch::runtime::Error load(const std::string& tokenizer_path) override; 38 39 ::executorch::runtime::Result<std::vector<uint64_t>> 40 encode(const std::string& input, int8_t bos, int8_t eos) const override; 41 42 ::executorch::runtime::Result<std::string> decode( 43 uint64_t prev_token, 44 uint64_t token) const override; 45 46 private: 47 template <typename T> 48 std::pair<std::optional<std::string>, re2::StringPiece> 49 _split_with_allowed_special_token( 50 re2::StringPiece& input, 51 const T& allowed_special) const; 52 53 void _encode( 54 re2::StringPiece& input, 55 std::vector<uint64_t>& ret, 56 uint64_t& last_piece_token_len) const; 57 58 template <typename T> 59 std::pair<std::vector<uint64_t>, uint64_t> _encode_with_special_token( 60 const std::string& text, 61 const T& allowed_special) const; 62 63 Encoder _build_special_token_encoder(ssize_t num_base_tokens) const; 64 65 std::unique_ptr<std::vector<std::string>> _special_tokens; 66 size_t _bos_token_index; 67 size_t _eos_token_index; 68 // Removed negative lookahead \s+(?!\S) since it's not supported by RE2. 69 const std::string _pattern = 70 R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; 71 Encoder _encoder; 72 Encoder _special_token_encoder; 73 Decoder _decoder; 74 Decoder _special_token_decoder; 75 76 Re2UPtr _regex; 77 Re2UPtr _special_token_regex; 78 }; 79 80 } // namespace llm 81 } // namespace extension 82 } // namespace executorch 83 84 namespace torch { 85 namespace executor { 86 // TODO(T197294990): Remove these deprecated aliases once all users have moved 87 // to the new `::executorch` namespaces. 88 using ::executorch::extension::llm::Decoder; 89 using ::executorch::extension::llm::Encoder; 90 using ::executorch::extension::llm::Re2UPtr; 91 using ::executorch::extension::llm::Tiktoken; 92 } // namespace executor 93 } // namespace torch 94