1 /*
2 * Copyright (c) 2024 MediaTek Inc.
3 *
4 * Licensed under the BSD License (the "License"); you may not use this file
5 * except in compliance with the License. See the license file in the root
6 * directory of this source tree for more details.
7 */
8
9 #include "llm_helper/include/token_embedding.h"
10 #include "FileMemMapper.h"
11 #include "llm_helper/include/llm_types.h"
12
13 #include <executorch/runtime/platform/assert.h>
14 #include <executorch/runtime/platform/log.h>
15
16 #include <filesystem>
17 #include <fstream>
18 #include <string>
19
20 namespace fs = std::filesystem;
21
22 namespace example {
23 namespace llm_helper {
24
TokenEmbeddingLut(const std::string & tokenEmbLutPath,const LLMType tokenEmbLutType,const size_t hiddenSize)25 TokenEmbeddingLut::TokenEmbeddingLut(
26 const std::string& tokenEmbLutPath,
27 const LLMType tokenEmbLutType,
28 const size_t hiddenSize)
29 : kTokenEmbLutType(tokenEmbLutType),
30 kTokenEmbLutTypeSize(getLLMTypeSize(tokenEmbLutType)),
31 kHiddenSize(hiddenSize),
32 kLutRowSizeBytes(kHiddenSize * kTokenEmbLutTypeSize) {
33 ET_CHECK_MSG(
34 fs::exists(tokenEmbLutPath),
35 "Token embedding lookup table file not found: %s",
36 tokenEmbLutPath.c_str());
37
38 ET_LOG(
39 Debug,
40 "Loading token embedding lookup table: %s",
41 tokenEmbLutPath.c_str());
42
43 mMemMappedEmbFile = std::make_unique<FileMemMapper>(tokenEmbLutPath);
44 mLutBuffer = mMemMappedEmbFile->getAddr<uint8_t*>();
45 const size_t lutFileSize = mMemMappedEmbFile->getSize();
46
47 mVocabSize = lutFileSize / hiddenSize / kTokenEmbLutTypeSize;
48 ET_LOG(Debug, "TokenEmbeddingLut: Vocab size = %zu", mVocabSize);
49 }
50
~TokenEmbeddingLut()51 TokenEmbeddingLut::~TokenEmbeddingLut() {}
52
setOutput(void * buffer,const size_t size)53 void TokenEmbeddingLut::setOutput(void* buffer, const size_t size) {
54 mOutputBuffer = reinterpret_cast<uint8_t*>(buffer);
55 mOutputBufferSize = size;
56 }
57
lookupEmbedding(const std::vector<uint64_t> & tokens)58 void TokenEmbeddingLut::lookupEmbedding(const std::vector<uint64_t>& tokens) {
59 const auto numTokens = tokens.size();
60 const size_t requiredOutputSize =
61 numTokens * kHiddenSize * kTokenEmbLutTypeSize;
62 if (mOutputBufferSize < requiredOutputSize) {
63 ET_LOG(
64 Error,
65 "Token embedding buffer size (%zu) is insufficient to hold embedding for %zu tokens "
66 "(requires %zu).",
67 mOutputBufferSize,
68 numTokens,
69 requiredOutputSize);
70 return;
71 }
72 if (mOutputBuffer == nullptr) {
73 ET_LOG(
74 Error,
75 "TokenEmbeddingLut: Output is not yet set for embedding lookup.");
76 return;
77 }
78 size_t outputOffset = 0;
79 for (const auto token : tokens) {
80 // Copy one row from lookup table per token
81 ET_CHECK_MSG(
82 token < mVocabSize, "Token id exceeds embedding lookup table range.");
83 const auto& rowIdx = token;
84 const size_t lutOffset = rowIdx * kLutRowSizeBytes;
85 std::memcpy(
86 mOutputBuffer + outputOffset, mLutBuffer + lutOffset, kLutRowSizeBytes);
87 outputOffset += kLutRowSizeBytes;
88 }
89 return;
90 }
91
92 } // namespace llm_helper
93 } // namespace example
94