# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import struct
import tempfile
import unittest
from unittest.mock import patch

from executorch.extension.llm.tokenizer.tokenizer import Tokenizer


class TestTokenizer(unittest.TestCase):
    @patch("executorch.extension.llm.tokenizer.tokenizer.SentencePieceProcessor")
    def test_export(self, mock_sp):
        # Set up the mock SentencePieceProcessor
        mock_sp.return_value.vocab_size.return_value = 0
        mock_sp.return_value.bos_id.return_value = 0
        mock_sp.return_value.eos_id.return_value = 0
        mock_sp.return_value.get_piece_size.return_value = 0
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=True) as temp:
            # Initialize the tokenizer with the temporary file as the model
            tokenizer = Tokenizer(temp.name)
            # Export the tokenizer to another temporary file
            with tempfile.NamedTemporaryFile(delete=True) as output:
                tokenizer.export(output.name)
                # Open the output file in binary mode and read the first 16 bytes
                with open(output.name, "rb") as f:
                    data = f.read(16)
                # Unpack the data as 4 integers
                vocab_size, bos_id, eos_id, max_token_length = struct.unpack(
                    "IIII", data
                )
                # Check that the integers match the properties of the tokenizer
                self.assertEqual(vocab_size, 0)
                self.assertEqual(bos_id, 0)
                self.assertEqual(eos_id, 0)
                # Check that the max token length is correct
                self.assertEqual(max_token_length, 0)
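

# Conventional entry point so the file can also be executed directly
# (e.g. `python test_tokenizer.py`). This guard is an assumption added for
# convenience; the upstream repo may instead invoke tests through a dedicated
# runner such as pytest or buck, in which case it is simply ignored.
if __name__ == "__main__":
    unittest.main()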