/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#include <algorithm>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>

#include "LlamaConfig.h"
#include "LlamaModelChunk.h"
#include "llm_helper/include/llm_types.h"

#include "llm_helper/include/mask_builder.h"
#include "llm_helper/include/rotary_embedding.h"

namespace example {

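// Helper returning the consecutive indexes {startIndex, ..., startIndex + count - 1}.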
inline std::vector<size_t> getIndexRange(
    const size_t startIndex,
    const size_t count) {
  std::vector<size_t> indexes(count);
  size_t counter = startIndex;
  for (auto& idx : indexes) {
    idx = counter++;
  }
  return indexes;
}

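// Expected IO layout, derived from the index constants initialized below:
//   inputs:  [0] main model input, [1] mask, [2..] rotary embedding inputs,
//            then the cache inputs.
//   outputs: [0] main model output, [1..] cache outputs.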
LlamaModelChunk::LlamaModelChunk(
    const ModelPathMap& modelPathMap,
    const LlamaModelOptions& modelOptions,
    const size_t initBatchSize,
    const size_t numCache,
    const size_t numRotEmbInputs,
    const RotaryEmbeddingMasterLut* rotEmbMasterLut)
    : ModelChunk(modelPathMap, initBatchSize),
      kMaxTokenLength(modelOptions.max_token_length),
      kCacheLength(modelOptions.cache_size),
      kMaskType(modelOptions.mask_type),
      kRotEmbMasterLut(rotEmbMasterLut),
      kCacheType(modelOptions.cache_type),
      kCacheTypeSize(llm_helper::getLLMTypeSize(kCacheType)),
      kMaskInputIndex(1),
      kRotEmbInputIndexes(getIndexRange(2, numRotEmbInputs)),
      kCacheInputIndexes(
          getIndexRange(kRotEmbInputIndexes.back() + 1, numCache)),
      kCacheOutputIndexes(getIndexRange(1, numCache)) {}

LlamaModelChunk::~LlamaModelChunk() {}

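// Two non-cache, non-rotary-embedding inputs: the main model input and the mask.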
size_t LlamaModelChunk::GetExpectedInputCount() const {
  const size_t rotEmbInputCount = kRotEmbInputIndexes.size();
  const size_t cacheInputCount = kCacheInputIndexes.size();
  return 2 + rotEmbInputCount + cacheInputCount;
}

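// One non-cache output: the main model output.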
size_t LlamaModelChunk::GetExpectedOutputCount() const {
  const size_t cacheOutputCount = kCacheOutputIndexes.size();
  return 1 + cacheOutputCount;
}

void LlamaModelChunk::Initialize() {
  LoadModels();
  GetModelIoInfo();
  CheckIoCount();
  PrepareCacheIOs();
  AllocateIoBuffers();
  InitMaskBuilder();
  InitCache();

  SetBackendInputs();
  SetBackendOutputs();
  mIsInitialized = true;
}

void LlamaModelChunk::CheckIoCount() {
  const auto& method = GetModelMethod();
  const size_t modelInputCount = method.inputs_size();
  const size_t modelOutputCount = method.outputs_size();
  ET_CHECK_MSG(
      modelInputCount == GetExpectedInputCount(),
      "Number of inputs does not match (expected %zu but got %zu).",
      GetExpectedInputCount(),
      modelInputCount);
  ET_CHECK_MSG(
      modelOutputCount == GetExpectedOutputCount(),
      "Number of outputs does not match (expected %zu but got %zu).",
      GetExpectedOutputCount(),
      modelOutputCount);
}

bool LlamaModelChunk::HotSwapModel(const size_t tokenBatchSize) {
  const auto status = ModelChunk::HotSwapModel(tokenBatchSize);

  // Force a mask rebuild because different token batch sizes produce different
  // mask shapes.
  mMaskBuilder->markMaskDirty();

  // Update the mask size
  const auto newMaskSizeBytes = mInputBufferInfos[kMaskInputIndex].nbytesUsed;
  mMaskBuilder->updateMaskSize(newMaskSizeBytes);

  return status;
}

void LlamaModelChunk::Reset() {
  mCurrentPadSize = 0;
  mCurrentTokenIndex = 0;
  InitCache(); // Reset cache to zeros
}

void LlamaModelChunk::SetLeftPadding(const size_t leftPadSize) {
  mCurrentPadSize = leftPadSize;
  mPaddingMode = PaddingMode::LEFT;

  // Notify mask builder about padding
  mMaskBuilder->notifyLeftPadding(leftPadSize);
}

void LlamaModelChunk::SetRightPadding(const size_t rightPadSize) {
  mCurrentPadSize = rightPadSize;
  mPaddingMode = PaddingMode::RIGHT;

  // Notify mask builder about padding
  mMaskBuilder->notifyRightPadding(rightPadSize);
}

size_t LlamaModelChunk::GetLeftPadding() const {
  return (mPaddingMode == PaddingMode::LEFT) ? mCurrentPadSize : 0;
}

size_t LlamaModelChunk::GetRightPadding() const {
  return (mPaddingMode == PaddingMode::RIGHT) ? mCurrentPadSize : 0;
}

void LlamaModelChunk::PaddingPostprocess() {
  if (mCurrentPadSize == 0) {
    return;
  }

  if (mPaddingMode == PaddingMode::RIGHT) {
    RightPaddingCachePostprocess();
  } else if (mPaddingMode == PaddingMode::LEFT) {
    LeftPaddingCachePostprocess();
  }
}

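// Zero out the cache entries written at the left-padded positions of the
// current batch.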
void LlamaModelChunk::LeftPaddingCachePostprocess() {
  // NOTE: This part might not actually be needed

  // The stride size is the same across all caches
  const size_t strideSizeBytes = GetCacheStrideSize();
  const size_t rowSize = kCacheLength * strideSizeBytes;

  const size_t numRows = GetCacheNumRows();

  const size_t offset = (kCacheLength - mTokenBatchSize) * strideSizeBytes;
  const size_t zeroCount = mCurrentPadSize * strideSizeBytes;

  // Fill padded sections with zeros
  for (const auto cacheInputIdx : kCacheInputIndexes) {
    auto cacheBuffer =
        reinterpret_cast<char*>(mInputBufferInfos[cacheInputIdx].data);
    for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
      // cacheBufRow points at the start of the row
      auto cacheBufRow = cacheBuffer + rowIdx * rowSize;
      std::memset(cacheBufRow + offset, 0, zeroCount);
    }
  }
}

void LlamaModelChunk::RightPaddingCachePostprocess() {
  // NOTE: AdvanceTokenIndex() hasn't been called for this inference step yet.
  const size_t numSeenToken = mCurrentTokenIndex + mTokenBatchSize;
  RollbackCache(mCurrentPadSize, numSeenToken);
}

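// Drops the most recent `rollbackTokCount` tokens from the cache by shifting the
// entries to be preserved toward the right end and zero-filling the vacated
// slots. `numSeenToken` is the number of tokens seen so far, including the
// current inference step.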
void LlamaModelChunk::RollbackCache(
    const size_t rollbackTokCount,
    const size_t numSeenToken) {
  if (rollbackTokCount == 0) {
    return; // do nothing
  }

  const size_t numSeenTokenAlive = std::min(numSeenToken, kCacheLength);
  const size_t firstNonEmptyIdx = kCacheLength - numSeenTokenAlive;
  const size_t preserveTokCount = (numSeenTokenAlive > rollbackTokCount)
      ? numSeenTokenAlive - rollbackTokCount
      : 0;

  if (!preserveTokCount) {
    // Clear cache to zeros
    InitCache();
    return;
  }

  const size_t strideSizeBytes = GetCacheStrideSize();
  const size_t rowSize = kCacheLength * strideSizeBytes;
  const size_t numRows = GetCacheNumRows();

  // Shift right and truncate rollbackTokCount, then fill left with zeros
  for (const auto cacheInputIdx : kCacheInputIndexes) {
    auto cacheBuffer =
        reinterpret_cast<char*>(mInputBufferInfos[cacheInputIdx].data);

    for (size_t rowIdx = 0; rowIdx < numRows; rowIdx++) {
      // Get the address pointing to the start of the row
      auto cacheBufRow = cacheBuffer + rowIdx * rowSize;

      // Move the section to be preserved to the right
      const size_t dstOffset =
          strideSizeBytes * (firstNonEmptyIdx + rollbackTokCount);
      const size_t srcOffset = strideSizeBytes * firstNonEmptyIdx;
      const size_t preserveSize = strideSizeBytes * preserveTokCount;
      ET_DCHECK(dstOffset + preserveSize <= rowSize);
      std::memmove(
          cacheBufRow + dstOffset, cacheBufRow + srcOffset, preserveSize);

      // Then fill the section that was moved out with zeros
      const size_t offset = firstNonEmptyIdx * strideSizeBytes;
      const size_t zeroCount = rollbackTokCount * strideSizeBytes;
      std::memset(cacheBufRow + offset, 0, zeroCount);
    }
  }
}

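// Refreshes the attention mask and rotary position embeddings for the upcoming
// inference step, which starts at the current token index.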
void LlamaModelChunk::UpdatePosEmbAndMask(const size_t numInputToken) {
  if (mCurrentTokenIndex + numInputToken > kMaxTokenLength) {
    ET_LOG(
        Fatal,
        "Attempting to generate tokens exceeding the supported max token length (%zu)",
        kMaxTokenLength);
  }
  if (mCurrentTokenIndex > 0 && GetLeftPadding() > 0) {
    ET_LOG(Fatal, "Left-padding is only allowed in the first prompt pass.");
  }
  mMaskBuilder->updateMask(mTokenBatchSize, mCurrentTokenIndex, numInputToken);
  SetPosEmbed(mCurrentTokenIndex);
}

void LlamaModelChunk::AdvanceTokenIndex() {
  // Exclude padded tokens
  const auto numValidInputToken = mTokenBatchSize - mCurrentPadSize;
  mCurrentTokenIndex += numValidInputToken;

  // Reset padding size
  mCurrentPadSize = 0;
}

size_t LlamaModelChunk::GetTokenIndex() const {
  return mCurrentTokenIndex;
}

void LlamaModelChunk::Run() {
  UpdatePosEmbAndMask(mTokenBatchSize);
  ModelChunk::Run();
  PaddingPostprocess();
  AdvanceTokenIndex();
}

void LlamaModelChunk::SetPosEmbed(const size_t tokenIndex) {
  if (tokenIndex >= kMaxTokenLength) {
    ET_LOG(
        Fatal,
        "Attempting to set rotary embedding using an index exceeding the supported max token length "
        "(%zu)",
        kMaxTokenLength);
  }

  auto getRotEmbInputs = [&]() {
    std::vector<void*> rotEmbInputs;
    rotEmbInputs.reserve(kRotEmbInputIndexes.size());
    for (const auto inputIdx : kRotEmbInputIndexes)
      rotEmbInputs.push_back(mInputBufferInfos[inputIdx].data);
    return rotEmbInputs;
  };
  kRotEmbMasterLut->setEmbed(
      getRotEmbInputs(),
      tokenIndex,
      mTokenBatchSize,
      GetLeftPadding(),
      GetRightPadding());
}

void LlamaModelChunk::PrepareCacheIOs() {
  // Get cache shape
  const auto method_meta = GetModelMethod().method_meta();
  const auto firstInCacheIdx = kCacheInputIndexes.front();
  mCacheShape = method_meta.input_tensor_meta(firstInCacheIdx)->sizes();

  // Link cache IOs
  const size_t numCaches = kCacheInputIndexes.size();
  for (size_t i = 0; i < numCaches; i++) {
    this->LinkModelIO(kCacheInputIndexes[i], kCacheOutputIndexes[i]);
  }
}

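// Number of cache rows: product of the cache dimensions preceding the cache
// length dimension.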
size_t LlamaModelChunk::GetCacheNumRows() const {
  return std::reduce(
      mCacheShape.begin(),
      mCacheShape.begin() + kCacheLengthDim,
      1,
      std::multiplies<>());
}

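// Per-token stride in bytes: product of the cache dimensions after the cache
// length dimension, times the cache element size.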
size_t LlamaModelChunk::GetCacheStrideSize() const {
  return std::reduce(
      mCacheShape.begin() + kCacheLengthDim + 1,
      mCacheShape.end(),
      kCacheTypeSize,
      std::multiplies<>());
}

void LlamaModelChunk::InitMaskBuilder() {
  const auto& maskBufferInfo = mInputBufferInfos[kMaskInputIndex];
  const auto maskBuffer = maskBufferInfo.data;
  const auto maskSizeBytes = maskBufferInfo.nbytesUsed;
  mMaskBuilder = std::make_unique<MaskBuilder>(
      maskBuffer, maskSizeBytes, kMaskType, kCacheLength);
  mMaskBuilder->buildMask(mTokenBatchSize, mCurrentTokenIndex);
}

void LlamaModelChunk::InitCache() {
  // Zero initialization
  for (const auto cacheIdx : kCacheInputIndexes) {
    const auto& inputCacheInfo = mInputBufferInfos[cacheIdx];
    char* cacheBuffer = reinterpret_cast<char*>(inputCacheInfo.data);
    const size_t cacheSizeBytes = inputCacheInfo.nbytes;
    std::memset(cacheBuffer, 0, cacheSizeBytes);
  }
}

} // namespace example