1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
#include "llm_helper/include/token_embedding.h"
#include "FileMemMapper.h"
#include "llm_helper/include/llm_types.h"

#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <cstring>
#include <filesystem>
#include <fstream>
#include <memory>
#include <string>
19 
20 namespace fs = std::filesystem;
21 
22 namespace example {
23 namespace llm_helper {
24 
TokenEmbeddingLut(const std::string & tokenEmbLutPath,const LLMType tokenEmbLutType,const size_t hiddenSize)25 TokenEmbeddingLut::TokenEmbeddingLut(
26     const std::string& tokenEmbLutPath,
27     const LLMType tokenEmbLutType,
28     const size_t hiddenSize)
29     : kTokenEmbLutType(tokenEmbLutType),
30       kTokenEmbLutTypeSize(getLLMTypeSize(tokenEmbLutType)),
31       kHiddenSize(hiddenSize),
32       kLutRowSizeBytes(kHiddenSize * kTokenEmbLutTypeSize) {
33   ET_CHECK_MSG(
34       fs::exists(tokenEmbLutPath),
35       "Token embedding lookup table file not found: %s",
36       tokenEmbLutPath.c_str());
37 
38   ET_LOG(
39       Debug,
40       "Loading token embedding lookup table: %s",
41       tokenEmbLutPath.c_str());
42 
43   mMemMappedEmbFile = std::make_unique<FileMemMapper>(tokenEmbLutPath);
44   mLutBuffer = mMemMappedEmbFile->getAddr<uint8_t*>();
45   const size_t lutFileSize = mMemMappedEmbFile->getSize();
46 
47   mVocabSize = lutFileSize / hiddenSize / kTokenEmbLutTypeSize;
48   ET_LOG(Debug, "TokenEmbeddingLut: Vocab size = %zu", mVocabSize);
49 }
50 
~TokenEmbeddingLut()51 TokenEmbeddingLut::~TokenEmbeddingLut() {}
52 
setOutput(void * buffer,const size_t size)53 void TokenEmbeddingLut::setOutput(void* buffer, const size_t size) {
54   mOutputBuffer = reinterpret_cast<uint8_t*>(buffer);
55   mOutputBufferSize = size;
56 }
57 
lookupEmbedding(const std::vector<uint64_t> & tokens)58 void TokenEmbeddingLut::lookupEmbedding(const std::vector<uint64_t>& tokens) {
59   const auto numTokens = tokens.size();
60   const size_t requiredOutputSize =
61       numTokens * kHiddenSize * kTokenEmbLutTypeSize;
62   if (mOutputBufferSize < requiredOutputSize) {
63     ET_LOG(
64         Error,
65         "Token embedding buffer size (%zu) is insufficient to hold embedding for %zu tokens "
66         "(requires %zu).",
67         mOutputBufferSize,
68         numTokens,
69         requiredOutputSize);
70     return;
71   }
72   if (mOutputBuffer == nullptr) {
73     ET_LOG(
74         Error,
75         "TokenEmbeddingLut: Output is not yet set for embedding lookup.");
76     return;
77   }
78   size_t outputOffset = 0;
79   for (const auto token : tokens) {
80     // Copy one row from lookup table per token
81     ET_CHECK_MSG(
82         token < mVocabSize, "Token id exceeds embedding lookup table range.");
83     const auto& rowIdx = token;
84     const size_t lutOffset = rowIdx * kLutRowSizeBytes;
85     std::memcpy(
86         mOutputBuffer + outputOffset, mLutBuffer + lutOffset, kLutRowSizeBytes);
87     outputOffset += kLutRowSizeBytes;
88   }
89   return;
90 }
91 
92 } // namespace llm_helper
93 } // namespace example
94