xref: /aosp_15_r20/external/executorch/examples/models/llama/tokenizer/tiktoken.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import os
8from logging import getLogger
9from pathlib import Path
10from typing import (
11    AbstractSet,
12    cast,
13    Collection,
14    Dict,
15    Iterator,
16    List,
17    Literal,
18    Optional,
19    Sequence,
20    Union,
21)
22
23import tiktoken
24
25from tiktoken.load import load_tiktoken_bpe
26
logger = getLogger(__name__)


# Largest slice of the input string passed to tiktoken in a single call.
# The tiktoken tokenizer can handle <=400k chars without raising
# pyo3_runtime.PanicException, so longer inputs are chunked to this size.
TIKTOKEN_MAX_ENCODE_CHARS: int = 400_000

# Workaround for https://github.com/openai/tiktoken/issues/195: inputs are
# additionally split so that no piece contains a run of more than this many
# consecutive whitespace (or consecutive non-whitespace) characters.
MAX_NO_WHITESPACES_CHARS: int = 25_000


# Module-wide Tokenizer cache, populated lazily by Tokenizer.get_instance().
_INSTANCE: Optional["Tokenizer"] = None
41
42
class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Wraps a ``tiktoken.Encoding`` built from a BPE rank file, extended with a
    block of ``num_reserved_special_tokens`` special tokens whose ids are
    assigned immediately after the base vocabulary.
    """

    # Special-token text (e.g. "<|begin_of_text|>") -> token id.
    special_tokens: Dict[str, int]

    # Total number of special-token ids appended after the base vocabulary:
    # the named special tokens plus generated "<|reserved_special_token_N|>"
    # fillers that pad the block out to this size.
    num_reserved_special_tokens = 256

    # Pre-tokenization regex handed to tiktoken as ``pat_str``.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    @classmethod
    def get_instance(cls):
        """
        Return a lazily created, module-wide Tokenizer.

        On first call the tokenizer is built from ``tokenizer.model`` located
        in the same directory as this source file and cached in the module
        global ``_INSTANCE``; later calls return the cached object.
        """
        global _INSTANCE

        if _INSTANCE is None:
            _INSTANCE = Tokenizer(
                os.path.join(os.path.dirname(__file__), "tokenizer.model")
            )
        return _INSTANCE

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.

        Raises:
            AssertionError: if ``model_path`` is not an existing file (only
                checked when assertions are enabled).
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|finetune_right_pad_id|>",
            "<|step_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eom_id|>",  # end of message
            "<|eot_id|>",  # end of turn
            "<|python_tag|>",
            "<|image|>",
        ]
        # Pad the named specials with generic reserved tokens so the special
        # block is exactly num_reserved_special_tokens long. Numbering starts
        # at 2 because reserved_special_token_0/1 are already named above.
        reserved_tokens = [
            f"<|reserved_special_token_{2 + i}|>"
            for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        # Special-token ids follow directly after the base BPE vocabulary.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(special_tokens)
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.eot_id: int = self.special_tokens["<|eot_id|>"]
        self.eom_id: int = self.special_tokens["<|eom_id|>"]
        self.python_tag_id: int = self.special_tokens["<|python_tag|>"]
        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
        self.stop_tokens = [
            self.eos_id,
            self.eom_id,
            self.eot_id,
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        # Chunk the input twice: first into TIKTOKEN_MAX_ENCODE_CHARS slices,
        # then so no piece has an over-long whitespace/non-whitespace run.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        # Seed with BOS up front rather than list.insert(0, ...) afterwards,
        # which would shift every encoded token.
        t: List[int] = [self.bos_id] if bos else []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    def decode_token(self, t: int, *, errors: str = "replace") -> str:
        """
        Decodes a single token ID into a string.

        A single byte-level BPE token may contain only part of a multi-byte
        UTF-8 character (e.g. one piece of a CJK character or emoji), so its
        bytes are not always valid UTF-8 on their own.

        Args:
            t (int): The token ID to be decoded.
            errors (str): Error handler passed to ``bytes.decode``. Defaults
                to "replace" (same as tiktoken's ``Encoding.decode``), which
                substitutes U+FFFD for incomplete sequences; pass "strict" to
                raise ``UnicodeDecodeError`` instead.

        Returns:
            str: The decoded string.
        """
        return self.model.decode_single_token_bytes(t).decode("utf-8", errors=errors)

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.

        Concatenating the yielded pieces reproduces `s`. An empty `s` yields
        a single empty string.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i, ch in enumerate(s):
            is_now_space = ch.isspace()

            if current_slice_is_space ^ is_now_space:
                # Character class changed: a new run starts at this char.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Run too long: emit everything before this char and
                    # start a new piece (and a new run count) here.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
225