1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/platform/runtime.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <csignal>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
15
16 using namespace ::testing;
17 using ::executorch::extension::llm::Tiktoken;
18 using ::executorch::extension::llm::Tokenizer;
19 using ::executorch::runtime::Error;
20 using ::executorch::runtime::Result;
21
namespace {
// Test case based on Llama 2
// size_t (not int32_t) so the size() comparison below is not mixed-sign.
static constexpr size_t kSpecialTokensSize = 256;
static constexpr size_t kBOSTokenIndex = 0;
static constexpr size_t kEOSTokenIndex = 1;

/// Builds the kSpecialTokensSize special-token strings the Tiktoken
/// constructor expects: ten explicitly named tokens first, then
/// "<|reserved_special_token_N|>" fillers with N continuing after the five
/// reserved tokens (0-4) named in the explicit list.
static inline std::unique_ptr<std::vector<std::string>> _get_special_tokens() {
  auto special_tokens = std::make_unique<std::vector<std::string>>();
  // Reserve up front so neither the initial fill nor the padding loop below
  // ever reallocates.
  special_tokens->reserve(kSpecialTokensSize);
  for (const char* token : {
           "<|begin_of_text|>",
           "<|end_of_text|>",
           "<|reserved_special_token_0|>",
           "<|reserved_special_token_1|>",
           "<|reserved_special_token_2|>",
           "<|reserved_special_token_3|>",
           "<|start_header_id|>",
           "<|end_header_id|>",
           "<|reserved_special_token_4|>",
           "<|eot_id|>"}) {
    special_tokens->emplace_back(token);
  }

  // Pad the rest of the special tokens with reserved tokens.
  // size_t rather than ssize_t: ssize_t is POSIX, not standard C++, and the
  // counter is never negative.
  size_t reserved_special_token_num = 5;
  while (special_tokens->size() < kSpecialTokensSize) {
    special_tokens->emplace_back(
        "<|reserved_special_token_" +
        std::to_string(reserved_special_token_num++) + "|>");
  }
  return special_tokens;
}
} // namespace
51
52 class TiktokenExtensionTest : public Test {
53 public:
SetUp()54 void SetUp() override {
55 executorch::runtime::runtime_init();
56 tokenizer_ = std::make_unique<Tiktoken>(
57 _get_special_tokens(), kBOSTokenIndex, kEOSTokenIndex);
58 modelPath_ = std::getenv("RESOURCES_PATH") +
59 std::string("/test_tiktoken_tokenizer.model");
60 }
61
62 std::unique_ptr<Tokenizer> tokenizer_;
63 std::string modelPath_;
64 };
65
TEST_F(TiktokenExtensionTest, EncodeWithoutLoadFails) {
  // Encoding before load() must report NotSupported rather than crash.
  const auto encoded = tokenizer_->encode("hello world", 0, 0);
  EXPECT_EQ(encoded.error(), Error::NotSupported);
}
70
TEST_F(TiktokenExtensionTest, DecodeWithoutLoadFails) {
  // Decoding before load() must report NotSupported rather than crash.
  const auto decoded = tokenizer_->decode(0, 0);
  EXPECT_EQ(decoded.error(), Error::NotSupported);
}
75
TEST_F(TiktokenExtensionTest, TokenizerVocabSizeIsExpected) {
  // After a successful load, the tokenizer reports the sizes/IDs baked into
  // the test model file.
  const Error load_error = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(load_error, Error::Ok);
  EXPECT_EQ(tokenizer_->vocab_size(), 128256);
  EXPECT_EQ(tokenizer_->bos_tok(), 128000);
  EXPECT_EQ(tokenizer_->eos_tok(), 128001);
}
83
TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
  const Error load_error = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(load_error, Error::Ok);
  // bos=1, eos=0: expect the BOS id followed by the two text tokens.
  const auto encoded = tokenizer_->encode("hello world", 1, 0);
  EXPECT_EQ(encoded.error(), Error::Ok);
  const std::vector<uint64_t> expected_ids = {128000, 15339, 1917};
  EXPECT_EQ(encoded.get().size(), expected_ids.size());
  for (size_t i = 0; i < expected_ids.size(); ++i) {
    EXPECT_EQ(encoded.get()[i], expected_ids[i]);
  }
}
94
TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
  const Error load_error = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(load_error, Error::Ok);
  // Each token id (decoded with previous token 0) maps to its text piece.
  const std::vector<uint64_t> token_ids = {128000, 15339, 1917};
  const std::vector<std::string> expected_pieces = {
      "<|begin_of_text|>", "hello", " world"};
  for (size_t i = 0; i < token_ids.size(); ++i) {
    const auto decoded = tokenizer_->decode(0, token_ids[i]);
    EXPECT_EQ(decoded.error(), Error::Ok);
    EXPECT_EQ(decoded.get(), expected_pieces[i]);
  }
}
106
TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
  const Error load_error = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(load_error, Error::Ok);
  // The vocab size is 128256; add 256 so the token id is safely outside the
  // vocab range.
  const auto decoded = tokenizer_->decode(0, 128256 + 256);
  EXPECT_EQ(decoded.error(), Error::NotSupported);
}
115
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidBOSIndex) {
  // gtest death test doesn't work on iOS:
  // https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
  // Only one special token is supplied, so BOS index 1 is out of bounds and
  // construction must abort.
  const auto construct_with_bad_bos = [] {
    std::make_unique<Tiktoken>(
        std::make_unique<std::vector<std::string>>(
            std::vector<std::string>{"<|end_of_text|>"}),
        1,
        0);
  };
  EXPECT_EXIT(
      construct_with_bad_bos(), ::testing::KilledBySignal(SIGABRT), "");
#endif
}
130
TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
  // gtest death test doesn't work on iOS:
  // https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
  // Only one special token is supplied, so EOS index 1 is out of bounds and
  // construction must abort.
  const auto construct_with_bad_eos = [] {
    std::make_unique<Tiktoken>(
        std::make_unique<std::vector<std::string>>(
            std::vector<std::string>{"<|begin_of_text|>"}),
        0,
        1);
  };
  EXPECT_EXIT(
      construct_with_bad_eos(), ::testing::KilledBySignal(SIGABRT), "");
#endif
}
145
TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) {
  // Loading a nonexistent file must fail with InvalidArgument.
  const std::string missing_path =
      std::string(std::getenv("RESOURCES_PATH")) + "/nonexistent.model";
  EXPECT_EQ(tokenizer_->load(missing_path.c_str()), Error::InvalidArgument);
}
153
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) {
  // The invalid-rank fixture file must be rejected with InvalidArgument.
  const std::string bad_path = std::string(std::getenv("RESOURCES_PATH")) +
      "/test_tiktoken_invalid_rank.model";
  EXPECT_EQ(tokenizer_->load(bad_path.c_str()), Error::InvalidArgument);
}
162
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) {
  // The invalid-base64 fixture file must be rejected with InvalidArgument.
  const std::string bad_path = std::string(std::getenv("RESOURCES_PATH")) +
      "/test_tiktoken_invalid_base64.model";
  EXPECT_EQ(tokenizer_->load(bad_path.c_str()), Error::InvalidArgument);
}
171
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) {
  // The no-space fixture file must be rejected with InvalidArgument.
  const std::string bad_path = std::string(std::getenv("RESOURCES_PATH")) +
      "/test_tiktoken_no_space.model";
  EXPECT_EQ(tokenizer_->load(bad_path.c_str()), Error::InvalidArgument);
}
180
TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) {
  // A BPE-format tokenizer file is not a tiktoken model and must be rejected
  // with InvalidArgument.
  const std::string bad_path =
      std::string(std::getenv("RESOURCES_PATH")) + "/test_bpe_tokenizer.bin";
  EXPECT_EQ(tokenizer_->load(bad_path.c_str()), Error::InvalidArgument);
}
189