/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#include <algorithm>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>

#include "LlamaConfig.h"
#include "LlamaModelChunk.h"
#include "llm_helper/include/llm_types.h"

#include "llm_helper/include/mask_builder.h"
#include "llm_helper/include/rotary_embedding.h"

namespace example {

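// Returns the consecutive index list {startIndex, ..., startIndex + count - 1}.
// For example, getIndexRange(2, 3) yields {2, 3, 4}.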
inline std::vector<size_t> getIndexRange(
    const size_t startIndex,
    const size_t count) {
  std::vector<size_t> indexes(count);
  size_t counter = startIndex;
  for (auto& idx : indexes) {
    idx = counter++;
  }
  return indexes;
}

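// The index constants initialized below encode the IO layout this chunk
// expects from the compiled model:
//   Inputs:  index 0 (assumed to be the embedding/hidden-state input),
//            index 1 (attention mask), indexes 2..1+numRotEmbInputs (rotary
//            embeddings), then numCache cache inputs.
//   Outputs: index 0 (assumed to be the hidden state or logits), then
//            numCache cache outputs.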
LlamaModelChunk::LlamaModelChunk(
    const ModelPathMap& modelPathMap,
    const LlamaModelOptions& modelOptions,
    const size_t initBatchSize,
    const size_t numCache,
    const size_t numRotEmbInputs,
    const RotaryEmbeddingMasterLut* rotEmbMasterLut)
    : ModelChunk(modelPathMap, initBatchSize),
      kMaxTokenLength(modelOptions.max_token_length),
      kCacheLength(modelOptions.cache_size),
      kMaskType(modelOptions.mask_type),
      kRotEmbMasterLut(rotEmbMasterLut),
      kCacheType(modelOptions.cache_type),
      kCacheTypeSize(llm_helper::getLLMTypeSize(kCacheType)),
      kMaskInputIndex(1),
      kRotEmbInputIndexes(getIndexRange(2, numRotEmbInputs)),
      kCacheInputIndexes(
          getIndexRange(kRotEmbInputIndexes.back() + 1, numCache)),
      kCacheOutputIndexes(getIndexRange(1, numCache)) {}

LlamaModelChunk::~LlamaModelChunk() {}

size_t LlamaModelChunk::GetExpectedInputCount() const {
  const size_t rotEmbInputCount = kRotEmbInputIndexes.size();
  const size_t cacheInputCount = kCacheInputIndexes.size();
  return 2 + rotEmbInputCount + cacheInputCount;
}

size_t LlamaModelChunk::GetExpectedOutputCount() const {
  const size_t cacheOutputCount = kCacheOutputIndexes.size();
  return 1 + cacheOutputCount;
}

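// Note: the order of the steps below appears to matter. IO buffers must be
// allocated before InitMaskBuilder() and InitCache(), since both write
// through the buffer pointers recorded in mInputBufferInfos.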
void LlamaModelChunk::Initialize() {
  LoadModels();
  GetModelIoInfo();
  CheckIoCount();
  PrepareCacheIOs();
  AllocateIoBuffers();
  InitMaskBuilder();
  InitCache();

  SetBackendInputs();
  SetBackendOutputs();
  mIsInitialized = true;
}

void LlamaModelChunk::CheckIoCount() {
  const auto& method = GetModelMethod();
  const size_t modelInputCount = method.inputs_size();
  const size_t modelOutputCount = method.outputs_size();
  ET_CHECK_MSG(
      modelInputCount == GetExpectedInputCount(),
      "Number of inputs does not match (expected %zu but got %zu).",
      GetExpectedInputCount(),
      modelInputCount);
  ET_CHECK_MSG(
      modelOutputCount == GetExpectedOutputCount(),
      "Number of outputs does not match (expected %zu but got %zu).",
      GetExpectedOutputCount(),
      modelOutputCount);
}

bool LlamaModelChunk::HotSwapModel(const size_t tokenBatchSize) {
  const auto status = ModelChunk::HotSwapModel(tokenBatchSize);

  // Force a mask rebuild, because different batch sizes produce different
  // mask shapes.
  mMaskBuilder->markMaskDirty();

  // Update mask size
  const auto newMaskSizeBytes = mInputBufferInfos[kMaskInputIndex].nbytesUsed;
  mMaskBuilder->updateMaskSize(newMaskSizeBytes);

  return status;
}

void LlamaModelChunk::Reset() {
  mCurrentPadSize = 0;
  mCurrentTokenIndex = 0;
  InitCache(); // Reset cache to zeros
}

void LlamaModelChunk::SetLeftPadding(const size_t leftPadSize) {
  mCurrentPadSize = leftPadSize;
  mPaddingMode = PaddingMode::LEFT;

  // Notify mask builder about padding
  mMaskBuilder->notifyLeftPadding(leftPadSize);
}

void LlamaModelChunk::SetRightPadding(const size_t rightPadSize) {
  mCurrentPadSize = rightPadSize;
  mPaddingMode = PaddingMode::RIGHT;

  // Notify mask builder about padding
  mMaskBuilder->notifyRightPadding(rightPadSize);
}

size_t LlamaModelChunk::GetLeftPadding() const {
  return (mPaddingMode == PaddingMode::LEFT) ? mCurrentPadSize : 0;
}

size_t LlamaModelChunk::GetRightPadding() const {
  return (mPaddingMode == PaddingMode::RIGHT) ? mCurrentPadSize : 0;
}

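// After an inference step that contained padded tokens, the cache entries
// written for those padding positions must be scrubbed so they do not
// pollute attention in subsequent steps.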
void LlamaModelChunk::PaddingPostprocess() {
  if (mCurrentPadSize == 0) {
    return;
  }

  if (mPaddingMode == PaddingMode::RIGHT) {
    RightPaddingCachePostprocess();
  } else if (mPaddingMode == PaddingMode::LEFT) {
    LeftPaddingCachePostprocess();
  }
}

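// Worked example (sizes are illustrative): with kCacheLength = 8,
// mTokenBatchSize = 4 and mCurrentPadSize = 2 (left padding), the model has
// just appended 4 entries at the right end of each cache row, of which the
// leftmost 2 belong to padding. A row [c0 c1 c2 c3 | p p t0 t1] becomes
// [c0 c1 c2 c3 | 0 0 t0 t1].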
void LlamaModelChunk::LeftPaddingCachePostprocess() {
  // NOTE: This part might not actually be needed

  // Stride size is the same across all caches
  const size_t strideSizeBytes = GetCacheStrideSize();
  const size_t rowSize = kCacheLength * strideSizeBytes;

  const size_t numRows = GetCacheNumRows();

  const size_t offset = (kCacheLength - mTokenBatchSize) * strideSizeBytes;
  const size_t zeroCount = mCurrentPadSize * strideSizeBytes;

  // Fill the padded sections with zeros
  for (const auto cacheInputIdx : kCacheInputIndexes) {
    auto cacheBuffer =
        reinterpret_cast<char*>(mInputBufferInfos[cacheInputIdx].data);
    for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
      // cacheBufRow points at the start of the row
      auto cacheBufRow = cacheBuffer + rowIdx * rowSize;
      std::memset(cacheBufRow + offset, 0, zeroCount);
    }
  }
}

void LlamaModelChunk::RightPaddingCachePostprocess() {
  // NOTE: AdvanceTokenIndex() hasn't been called for this inference step yet.
  const size_t numSeenToken = mCurrentTokenIndex + mTokenBatchSize;
  RollbackCache(mCurrentPadSize, numSeenToken);
}

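// Worked example (sizes are illustrative): in a length-8 cache holding 4
// live tokens, a row looks like [0 0 0 0 | t0 t1 t2 t3]. Calling
// RollbackCache(2, 4) shifts the preserved tokens right and zero-fills the
// vacated slots, giving [0 0 0 0 0 0 | t0 t1].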
void LlamaModelChunk::RollbackCache(
    const size_t rollbackTokCount,
    const size_t numSeenToken) {
  if (rollbackTokCount == 0) {
    return; // do nothing
  }

  const size_t numSeenTokenAlive = std::min(numSeenToken, kCacheLength);
  const size_t firstNonEmptyIdx = kCacheLength - numSeenTokenAlive;
  const size_t preserveTokCount = (numSeenTokenAlive > rollbackTokCount)
      ? numSeenTokenAlive - rollbackTokCount
      : 0;

  if (!preserveTokCount) {
    // Clear the whole cache to zeros
    InitCache();
    return;
  }

  const size_t strideSizeBytes = GetCacheStrideSize();
  const size_t rowSize = kCacheLength * strideSizeBytes;
  const size_t numRows = GetCacheNumRows();

  // Shift the preserved section right by rollbackTokCount tokens, then
  // zero-fill the vacated section on the left.
  for (const auto cacheInputIdx : kCacheInputIndexes) {
    auto cacheBuffer =
        reinterpret_cast<char*>(mInputBufferInfos[cacheInputIdx].data);

    for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
      // Get the address pointing to the start of the row
      auto cacheBufRow = cacheBuffer + rowIdx * rowSize;

      // Move the section to be preserved to the right
      const size_t dstOffset =
          strideSizeBytes * (firstNonEmptyIdx + rollbackTokCount);
      const size_t srcOffset = strideSizeBytes * firstNonEmptyIdx;
      const size_t preserveSize = strideSizeBytes * preserveTokCount;
      ET_DCHECK(dstOffset + preserveSize <= rowSize);
      std::memmove(
          cacheBufRow + dstOffset, cacheBufRow + srcOffset, preserveSize);

      // Then zero-fill the section that was vacated
      const size_t offset = firstNonEmptyIdx * strideSizeBytes;
      const size_t zeroCount = rollbackTokCount * strideSizeBytes;
      std::memset(cacheBufRow + offset, 0, zeroCount);
    }
  }
}

void LlamaModelChunk::UpdatePosEmbAndMask(const size_t numInputToken) {
  if (mCurrentTokenIndex + numInputToken > kMaxTokenLength) {
    ET_LOG(
        Fatal,
        "Attempting to generate tokens exceeding the supported max token length (%zu)",
        kMaxTokenLength);
  }
  if (mCurrentTokenIndex > 0 && GetLeftPadding() > 0) {
    ET_LOG(Fatal, "Left-padding is only allowed in the first prompt pass.");
  }
  mMaskBuilder->updateMask(mTokenBatchSize, mCurrentTokenIndex, numInputToken);
  SetPosEmbed(mCurrentTokenIndex);
}

void LlamaModelChunk::AdvanceTokenIndex() {
  // Exclude padded tokens
  const auto numValidInputToken = mTokenBatchSize - mCurrentPadSize;
  mCurrentTokenIndex += numValidInputToken;

  // Reset padding size
  mCurrentPadSize = 0;
}

size_t LlamaModelChunk::GetTokenIndex() const {
  return mCurrentTokenIndex;
}

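// One full inference step: refresh the mask and rotary embeddings for the
// current position, run the model, scrub cache entries produced by padded
// tokens, then advance the token index past the valid tokens.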
void LlamaModelChunk::Run() {
  UpdatePosEmbAndMask(mTokenBatchSize);
  ModelChunk::Run();
  PaddingPostprocess();
  AdvanceTokenIndex();
}

void LlamaModelChunk::SetPosEmbed(const size_t tokenIndex) {
  if (tokenIndex >= kMaxTokenLength) {
    ET_LOG(
        Fatal,
        "Attempting to set rotary embedding using index exceeding the supported max token length "
        "(%zu)",
        kMaxTokenLength);
  }

  auto getRotEmbInputs = [&]() {
    std::vector<void*> rotEmbInputs;
    rotEmbInputs.reserve(kRotEmbInputIndexes.size());
    for (const auto inputIdx : kRotEmbInputIndexes)
      rotEmbInputs.push_back(mInputBufferInfos[inputIdx].data);
    return rotEmbInputs;
  };
  kRotEmbMasterLut->setEmbed(
      getRotEmbInputs(),
      tokenIndex,
      mTokenBatchSize,
      GetLeftPadding(),
      GetRightPadding());
}

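// Linking each cache input to its corresponding cache output lets the cache
// produced by one step be consumed directly by the next (LinkModelIO is
// expected to make the two share the same buffer).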
void LlamaModelChunk::PrepareCacheIOs() {
  // Get cache shape
  const auto method_meta = GetModelMethod().method_meta();
  const auto firstInCacheIdx = kCacheInputIndexes.front();
  mCacheShape = method_meta.input_tensor_meta(firstInCacheIdx)->sizes();

  // Link cache IOs
  const size_t numCaches = kCacheInputIndexes.size();
  for (size_t i = 0; i < numCaches; i++) {
    this->LinkModelIO(kCacheInputIndexes[i], kCacheOutputIndexes[i]);
  }
}

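// The cache tensor is treated as GetCacheNumRows() rows, each holding
// kCacheLength entries of GetCacheStrideSize() bytes along the cache-length
// dimension kCacheLengthDim (declared in the header). Illustrative example:
// for mCacheShape = [2, 16, cacheLength, 64] with kCacheLengthDim == 2 and a
// 2-byte cache type, GetCacheNumRows() == 2 * 16 == 32 and
// GetCacheStrideSize() == 64 * 2 == 128 bytes per entry.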
size_t LlamaModelChunk::GetCacheNumRows() const {
  return std::reduce(
      mCacheShape.begin(),
      mCacheShape.begin() + kCacheLengthDim,
      1,
      std::multiplies<>());
}

size_t LlamaModelChunk::GetCacheStrideSize() const {
  return std::reduce(
      mCacheShape.begin() + kCacheLengthDim + 1,
      mCacheShape.end(),
      kCacheTypeSize,
      std::multiplies<>());
}

void LlamaModelChunk::InitMaskBuilder() {
  const auto& maskBufferInfo = mInputBufferInfos[kMaskInputIndex];
  const auto maskBuffer = maskBufferInfo.data;
  const auto maskSizeBytes = maskBufferInfo.nbytesUsed;
  mMaskBuilder = std::make_unique<MaskBuilder>(
      maskBuffer, maskSizeBytes, kMaskType, kCacheLength);
  mMaskBuilder->buildMask(mTokenBatchSize, mCurrentTokenIndex);
}

void LlamaModelChunk::InitCache() {
  // Zero initialization
  for (const auto cacheIdx : kCacheInputIndexes) {
    const auto& inputCacheInfo = mInputBufferInfos[cacheIdx];
    char* cacheBuffer = reinterpret_cast<char*>(inputCacheInfo.data);
    const size_t cacheSizeBytes = inputCacheInfo.nbytes;
    std::memset(cacheBuffer, 0, cacheSizeBytes);
  }
}

} // namespace example