xref: /aosp_15_r20/external/executorch/extension/llm/tokenizer/test/test_tokenizer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8import struct
9import tempfile
10import unittest
11from unittest.mock import patch
12
13from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
14
15
class TestTokenizer(unittest.TestCase):
    """Tests for the header that ``Tokenizer.export`` writes to disk."""

    @patch("executorch.extension.llm.tokenizer.tokenizer.SentencePieceProcessor")
    def test_export(self, mock_sp):
        """Export a tokenizer built on a mocked processor and check its header.

        The patched ``SentencePieceProcessor`` instance reports 0 for
        ``vocab_size``, ``bos_id``, ``eos_id`` and ``get_piece_size``. The
        first 16 bytes of the exported file are decoded as four native
        unsigned ints -- read here as vocab size, BOS id, EOS id and max
        token length, in that order -- and each one is expected to be 0.
        """
        # Configure the instance that the patched class hands back.
        processor = mock_sp.return_value
        processor.vocab_size.return_value = 0
        processor.bos_id.return_value = 0
        processor.eos_id.return_value = 0
        processor.get_piece_size.return_value = 0

        # The model file only has to exist: the processor that would parse
        # it is mocked out.
        with tempfile.NamedTemporaryFile(delete=True) as model_file:
            tokenizer = Tokenizer(model_file.name)

            # Write the export into a second scratch file and re-open it by
            # name to inspect what was written.
            with tempfile.NamedTemporaryFile(delete=True) as exported_file:
                tokenizer.export(exported_file.name)

                with open(exported_file.name, "rb") as reader:
                    header = reader.read(16)

                # Field order: vocab size, BOS id, EOS id, max token length.
                header_fields = struct.unpack("IIII", header)
                for field_value in header_fields:
                    self.assertEqual(field_value, 0)
44