1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
#include "llm_types.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
15 
16 namespace example {
17 
18 class FileMemMapper;
19 
20 namespace llm_helper {
21 
22 class TokenEmbeddingLut {
23  public:
24   TokenEmbeddingLut(
25       const std::string& tokenEmbLutPath,
26       const LLMType tokenEmbLutType,
27       const size_t hiddenSize);
28 
29   ~TokenEmbeddingLut();
30 
31   void setOutput(void* buffer, const size_t size);
32 
33   void lookupEmbedding(const std::vector<uint64_t>& tokens);
34 
35  private:
36   // Source lookup table
37   uint8_t* mLutBuffer = nullptr;
38   const LLMType kTokenEmbLutType;
39   const size_t kTokenEmbLutTypeSize;
40   const size_t kHiddenSize;
41   const size_t kLutRowSizeBytes;
42   size_t mVocabSize;
43 
44   // Output write buffer
45   uint8_t* mOutputBuffer = nullptr;
46   size_t mOutputBufferSize = 0;
47 
48   std::unique_ptr<FileMemMapper> mMemMappedEmbFile;
49 };
50 
51 } // namespace llm_helper
52 } // namespace example
53