/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/platform/runtime.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <sstream>
#include <vector>

using namespace ::testing;
using ::executorch::extension::llm::Tiktoken;
using ::executorch::extension::llm::Tokenizer;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace {
// Test case based on the Llama 3 tokenizer.
static constexpr int32_t kSpecialTokensSize = 256;
static constexpr size_t kBOSTokenIndex = 0;
static constexpr size_t kEOSTokenIndex = 1;
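// Returns the full list of kSpecialTokensSize special tokens: the named
// tokens below, padded out with numbered reserved tokens.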
static inline std::unique_ptr<std::vector<std::string>> _get_special_tokens() {
  auto special_tokens =
      std::make_unique<std::vector<std::string>>(std::vector<std::string>{
          "<|begin_of_text|>",
          "<|end_of_text|>",
          "<|reserved_special_token_0|>",
          "<|reserved_special_token_1|>",
          "<|reserved_special_token_2|>",
          "<|reserved_special_token_3|>",
          "<|start_header_id|>",
          "<|end_header_id|>",
          "<|reserved_special_token_4|>",
          "<|eot_id|>"});

  // pad the rest of the special tokens with reserved tokens
  ssize_t reserved_special_token_num = 5;
  while (special_tokens->size() < kSpecialTokensSize) {
    special_tokens->emplace_back(
        "<|reserved_special_token_" +
        std::to_string(reserved_special_token_num++) + "|>");
  }
  return special_tokens;
}
} // namespace

class TiktokenExtensionTest : public Test {
 public:
  void SetUp() override {
    executorch::runtime::runtime_init();
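    // kBOSTokenIndex and kEOSTokenIndex are indices into the special token
    // list above, i.e. "<|begin_of_text|>" and "<|end_of_text|>".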
    tokenizer_ = std::make_unique<Tiktoken>(
        _get_special_tokens(), kBOSTokenIndex, kEOSTokenIndex);
    modelPath_ = std::getenv("RESOURCES_PATH") +
        std::string("/test_tiktoken_tokenizer.model");
  }

  std::unique_ptr<Tokenizer> tokenizer_;
  std::string modelPath_;
};

TEST_F(TiktokenExtensionTest, EncodeWithoutLoadFails) {
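  // encode() takes the input text plus the number of BOS and EOS tokens to
  // add; before load() it should fail rather than produce tokens.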
  Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
  EXPECT_EQ(res.error(), Error::NotSupported);
}

TEST_F(TiktokenExtensionTest, DecodeWithoutLoadFails) {
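  // decode() maps one token id back to its text, given the previous token
  // id; it should likewise fail before load().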
  auto result = tokenizer_->decode(0, 0);
  EXPECT_EQ(result.error(), Error::NotSupported);
}

TEST_F(TiktokenExtensionTest, TokenizerVocabSizeIsExpected) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
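  // 128256 = 128000 BPE ranks + 256 special tokens; BOS and EOS are the
  // first two special token ids.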
  EXPECT_EQ(tokenizer_->vocab_size(), 128256);
  EXPECT_EQ(tokenizer_->bos_tok(), 128000);
  EXPECT_EQ(tokenizer_->eos_tok(), 128001);
}

TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
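  // bos = 1 prepends one BOS token and eos = 0 appends none, so the result
  // is [BOS, "hello", " world"].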
  Result<std::vector<uint64_t>> out = tokenizer_->encode("hello world", 1, 0);
  EXPECT_EQ(out.error(), Error::Ok);
  EXPECT_EQ(out.get().size(), 3);
  EXPECT_EQ(out.get()[0], 128000);
  EXPECT_EQ(out.get()[1], 15339);
  EXPECT_EQ(out.get()[2], 1917);
}

TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  std::vector<std::string> expected = {"<|begin_of_text|>", "hello", " world"};
  std::vector<uint64_t> tokens = {128000, 15339, 1917};
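  // Decode each token on its own; special tokens decode to their literal
  // string form.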
  for (size_t i = 0; i < tokens.size(); i++) {
    Result<std::string> out = tokenizer_->decode(0, tokens[i]);
    EXPECT_EQ(out.error(), Error::Ok);
    EXPECT_EQ(out.get(), expected[i]);
  }
}

TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) {
  Error res = tokenizer_->load(modelPath_.c_str());
  EXPECT_EQ(res, Error::Ok);
  // The vocab size is 128256; add 256 so the token is out of vocab range.
  Result<std::string> out = tokenizer_->decode(0, 128256 + 256);
  EXPECT_EQ(out.error(), Error::NotSupported);
}

TEST_F(TiktokenExtensionTest, ConstructionWithInvalidBOSIndex) {
  // gtest death tests don't work on iOS:
  // https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
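  // A BOS index of 1 is out of bounds for a one-element special token list,
  // so construction is expected to abort.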
  EXPECT_EXIT(
      std::make_unique<Tiktoken>(
          std::make_unique<std::vector<std::string>>(
              std::vector<std::string>{"<|end_of_text|>"}),
          1,
          0),
      ::testing::KilledBySignal(SIGABRT),
      "");
#endif
}

TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
  // gtest death tests don't work on iOS:
  // https://github.com/google/googletest/issues/2834
#if !GTEST_OS_IOS
  EXPECT_EXIT(
      std::make_unique<Tiktoken>(
          std::make_unique<std::vector<std::string>>(
              std::vector<std::string>{"<|begin_of_text|>"}),
          0,
          1),
      ::testing::KilledBySignal(SIGABRT),
      "");
#endif
}

TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) {
  auto invalidModelPath =
      std::getenv("RESOURCES_PATH") + std::string("/nonexistent.model");

  Error res = tokenizer_->load(invalidModelPath.c_str());
  EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) {
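  // A tiktoken model file stores one "<base64 token> <rank>" pair per line;
  // a malformed rank should be rejected as invalid.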
  auto invalidModelPath = std::getenv("RESOURCES_PATH") +
      std::string("/test_tiktoken_invalid_rank.model");

  Error res = tokenizer_->load(invalidModelPath.c_str());

  EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) {
  auto invalidModelPath = std::getenv("RESOURCES_PATH") +
      std::string("/test_tiktoken_invalid_base64.model");

  Error res = tokenizer_->load(invalidModelPath.c_str());

  EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) {
  auto invalidModelPath = std::getenv("RESOURCES_PATH") +
      std::string("/test_tiktoken_no_space.model");

  Error res = tokenizer_->load(invalidModelPath.c_str());

  EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) {
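  // A binary BPE tokenizer file is not a valid tiktoken model, so loading
  // it should fail the same way.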
  auto invalidModelPath =
      std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin");

  Error res = tokenizer_->load(invalidModelPath.c_str());

  EXPECT_EQ(res, Error::InvalidArgument);
}
189