# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    cast,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import tiktoken

from tiktoken.load import load_tiktoken_bpe

logger = getLogger(__name__)


# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000


_INSTANCE = None


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    @classmethod
    def get_instance(cls):
        global _INSTANCE

        if _INSTANCE is None:
            _INSTANCE = Tokenizer(
                os.path.join(os.path.dirname(__file__), "tokenizer.model")
            )
        return _INSTANCE

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|finetune_right_pad_id|>",
            "<|step_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eom_id|>",  # end of message
            "<|eot_id|>",  # end of turn
            "<|python_tag|>",
            "<|image|>",
        ]
        reserved_tokens = [
            f"<|reserved_special_token_{2 + i}|>"
            for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(special_tokens)
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.eot_id: int = self.special_tokens["<|eot_id|>"]
        self.eom_id: int = self.special_tokens["<|eom_id|>"]
        self.python_tag_id = self.special_tokens["<|python_tag|>"]
        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
        self.stop_tokens = [
            self.eos_id,
            self.special_tokens["<|eom_id|>"],
            self.special_tokens["<|eot_id|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a sequence of token IDs into a string.

        Args:
            t (Sequence[int]): The sequence of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    def decode_token(self, t: int) -> str:
        """
        Decodes a single token ID into a string.

        Args:
            t (int): The token ID to be decoded.

        Returns:
            str: The decoded string.
        """
        return self.model.decode_single_token_bytes(t).decode("utf-8")

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
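

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes a
    # "tokenizer.model" file is present next to this file, which is what
    # get_instance() expects; adjust the path or construct Tokenizer(model_path)
    # directly if your model lives elsewhere. It round-trips a short string
    # through encode() and decode().
    tok = Tokenizer.get_instance()
    ids = tok.encode("Hello, world!", bos=True, eos=True)
    print(ids)              # token IDs, starting with bos_id and ending with eos_id
    print(tok.decode(ids))  # decoded text; the BOS/EOS markers render as their literal strings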