1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #include "llm_helper/include/mask_builder.h"
10 
11 #include <executorch/runtime/platform/assert.h>
12 #include <executorch/runtime/platform/log.h>
13 
14 namespace example {
15 namespace llm_helper {
16 
17 // Define mask values for different types
18 template <typename T>
19 struct MaskVal;
20 
21 #define __DECL_MASK__(TYPE, TRUE_VAL, FALSE_VAL) \
22   template <>                                    \
23   struct MaskVal<TYPE> {                         \
24     static constexpr TYPE kTrue = TRUE_VAL;      \
25     static constexpr TYPE kFalse = FALSE_VAL;    \
26   };
27 
__DECL_MASK__(bool,true,false)28 __DECL_MASK__(bool, true, false)
29 __DECL_MASK__(int16_t, 0, -32768)
30 __DECL_MASK__(__fp16, 0, -100)
31 __DECL_MASK__(float, 0, -100)
32 #undef __DECL_MASK__
33 
34 MaskBuilder::MaskBuilder(
35     void* maskBuffer,
36     const size_t maskSizeBytes,
37     const LLMType maskType,
38     const size_t cacheLength)
39     : mMaskBuffer(maskBuffer),
40       mMaskSizeBytes(maskSizeBytes),
41       kMaskType(maskType),
42       kMaskTypeSize(getLLMTypeSize(maskType)),
43       kCacheLength(cacheLength) {}
44 
~MaskBuilder()45 MaskBuilder::~MaskBuilder() {}
46 
updateMaskSize(const size_t sizeBytes)47 void MaskBuilder::updateMaskSize(const size_t sizeBytes) {
48   mMaskSizeBytes = sizeBytes;
49 }
50 
markMaskDirty()51 void MaskBuilder::markMaskDirty() {
52   mIsMaskUpdatable = false;
53 }
54 
55 template <typename MaskType>
buildMask(const size_t tokenBatchSize,const size_t numSeenToken)56 void MaskBuilder::buildMask(
57     const size_t tokenBatchSize,
58     const size_t numSeenToken) {
59   constexpr auto maskTrue = MaskVal<MaskType>::kTrue;
60   constexpr auto maskFalse = MaskVal<MaskType>::kFalse;
61   const size_t maskLength = kCacheLength + tokenBatchSize;
62 
63   // The mask is a combination (concat) of input cache mask and attention mask
64   const size_t startTrueIdx =
65       kCacheLength - std::min(kCacheLength, numSeenToken);
66 
67   const size_t rowSize = mMaskSizeBytes / tokenBatchSize / kMaskTypeSize;
68 
69   const size_t expectedMaskSizeBytes =
70       tokenBatchSize * maskLength * kMaskTypeSize;
71   // Use '<' instead of '!=' because mMaskSizeBytes may be padded by compiler to
72   // fit HW
73   if (mMaskSizeBytes < expectedMaskSizeBytes) {
74     ET_LOG(
75         Info,
76         "Warn: Model input mask size (%zu) < mask size to be built (%zu). "
77         "Please ensure your model options are set correctly.",
78         mMaskSizeBytes,
79         expectedMaskSizeBytes);
80   }
81 
82   // There are tokenBatchSize number of rows
83   for (size_t inTokIdx = 0; inTokIdx < tokenBatchSize; inTokIdx++) {
84     const auto& rowIdx = inTokIdx; // For clarity
85     auto curMaskBuffer =
86         reinterpret_cast<MaskType*>(mMaskBuffer) + rowIdx * rowSize;
87     size_t i = 0; // Buffer write index
88 
89     // Set the (rectangle) input cache mask
90     while (i < startTrueIdx)
91       curMaskBuffer[i++] = maskFalse;
92     while (i < kCacheLength)
93       curMaskBuffer[i++] = maskTrue;
94 
95     // Set the (triangle) attention mask
96     const size_t attnTrueCount = inTokIdx + 1;
97     for (size_t counter = 0; counter < attnTrueCount; counter++) {
98       curMaskBuffer[i++] = maskTrue;
99     }
100     // Fill the remaining with False
101     while (i < maskLength)
102       curMaskBuffer[i++] = maskFalse;
103   }
104 
105   // Modify mask for padding if needed. Mask is not updatable if modified for
106   // padding.
107   mIsMaskUpdatable = !adjustMaskForPadding<MaskType>(tokenBatchSize);
108 }
109 
110 template <typename MaskType>
updateMask(const size_t tokenBatchSize,const size_t numSeenToken,const size_t length)111 void MaskBuilder::updateMask(
112     const size_t tokenBatchSize,
113     const size_t numSeenToken,
114     const size_t length) {
115   constexpr auto maskTrue = MaskVal<MaskType>::kTrue;
116 
117   if (!mIsMaskUpdatable) {
118     buildMask<MaskType>(tokenBatchSize, numSeenToken);
119     return;
120   }
121 
122   // Only set True for seen token
123   const size_t trueCount = std::min(length, numSeenToken);
124   if (!trueCount) {
125     // Modify mask for padding if needed. Mask is not updatable if modified for
126     // padding.
127     mIsMaskUpdatable = !adjustMaskForPadding<MaskType>(tokenBatchSize);
128     return;
129   }
130 
131   // The mask is a combination (concat) of input cache mask and attention mask
132   auto maskBuffer = reinterpret_cast<MaskType*>(mMaskBuffer);
133 
134   const size_t rowSize = mMaskSizeBytes / tokenBatchSize / kMaskTypeSize;
135 
136   // Only modify the left rectangle part
137   const size_t startTrueOffset =
138       kCacheLength - std::min(kCacheLength, numSeenToken);
139 
140   for (size_t inTokIdx = 0; inTokIdx < tokenBatchSize; inTokIdx++) {
141     const auto& rowIdx = inTokIdx; // For clarity
142     auto curMaskBuffer = maskBuffer + rowIdx * rowSize + startTrueOffset;
143     std::fill(curMaskBuffer, curMaskBuffer + trueCount, maskTrue);
144   }
145   // Modify mask for padding if needed. Mask is not updatable if modified for
146   // padding.
147   mIsMaskUpdatable = !adjustMaskForPadding<MaskType>(tokenBatchSize);
148 }
149 
buildMask(const size_t tokenBatchSize,const size_t numSeenToken)150 void MaskBuilder::buildMask(
151     const size_t tokenBatchSize,
152     const size_t numSeenToken) {
153   switch (kMaskType) {
154     case LLMType::INT16:
155       buildMask<int16_t>(tokenBatchSize, numSeenToken);
156       return;
157     case LLMType::FP16:
158       buildMask<__fp16>(tokenBatchSize, numSeenToken);
159       return;
160     case LLMType::FP32:
161       buildMask<float>(tokenBatchSize, numSeenToken);
162       return;
163     default:
164       break;
165   }
166   ET_LOG(
167       Fatal,
168       "Attempting to build mask with type %s. Supported types are INT16, FP16, FP32.",
169       getLLMTypeName(kMaskType));
170 }
171 
updateMask(const size_t tokenBatchSize,const size_t numSeenToken,const size_t length)172 void MaskBuilder::updateMask(
173     const size_t tokenBatchSize,
174     const size_t numSeenToken,
175     const size_t length) {
176   switch (kMaskType) {
177     case LLMType::INT16:
178       updateMask<int16_t>(tokenBatchSize, numSeenToken, length);
179       return;
180     case LLMType::FP16:
181       updateMask<__fp16>(tokenBatchSize, numSeenToken, length);
182       return;
183     case LLMType::FP32:
184       updateMask<float>(tokenBatchSize, numSeenToken, length);
185       return;
186     default:
187       break;
188   }
189   ET_LOG(
190       Fatal,
191       "Attempting to update with an unsupported mask type. "
192       "Supported types are INT16, FP16, FP32.");
193 }
194 
notifyLeftPadding(const size_t padLength)195 void MaskBuilder::notifyLeftPadding(const size_t padLength) {
196   ET_CHECK_MSG(
197       mRightPadLength == 0,
198       "Attempting to set left pad after right pad has been set.");
199   if (mLeftPadLength > 0) {
200     ET_LOG(
201         Info,
202         "Warn: Calling notifyLeftPadding() multiple times before building/updating mask.");
203   }
204   mLeftPadLength = padLength;
205 }
206 
notifyRightPadding(const size_t padLength)207 void MaskBuilder::notifyRightPadding(const size_t padLength) {
208   ET_CHECK_MSG(
209       mLeftPadLength == 0,
210       "Attempting to set right pad after left pad has been set.");
211   if (mRightPadLength > 0) {
212     ET_LOG(
213         Info,
214         "Warn: Calling notifyLeftPadding() multiple times before building/updating mask.");
215   }
216   mRightPadLength = padLength;
217 }
218 
219 template <typename MaskType>
adjustMaskForPadding(const size_t tokenBatchSize)220 bool MaskBuilder::adjustMaskForPadding(const size_t tokenBatchSize) {
221   if (mLeftPadLength + mRightPadLength == 0) {
222     return false; // No need to modify mask since no padding
223   }
224   ET_DCHECK_MSG(
225       mLeftPadLength == 0 || mRightPadLength == 0,
226       "Only allow setting either left or right pad");
227   constexpr auto maskFalse = MaskVal<MaskType>::kFalse;
228   const size_t maskLength = kCacheLength + tokenBatchSize;
229 
230   // The mask is a combination (concat) of input cache mask and attention mask
231   auto maskBuffer = reinterpret_cast<MaskType*>(mMaskBuffer);
232 
233   const size_t rowSize = mMaskSizeBytes / tokenBatchSize / kMaskTypeSize;
234 
235   if (mLeftPadLength > 0) {
236     // Mask the padded rows
237     for (size_t inTokIdx = 0; inTokIdx < mLeftPadLength; inTokIdx++) {
238       auto curMaskBuffer = maskBuffer + inTokIdx * rowSize;
239       std::fill(curMaskBuffer, curMaskBuffer + maskLength, maskFalse);
240     }
241     // Mask the padded attention region
242     for (size_t inTokIdx = mLeftPadLength; inTokIdx < tokenBatchSize;
243          inTokIdx++) {
244       auto curMaskBuffer = maskBuffer + inTokIdx * rowSize + kCacheLength;
245       // Anything from inTokIdx + 1 onwards is already False, so can skip them.
246       const size_t maskPadCount = std::min(mLeftPadLength, inTokIdx + 1);
247       std::fill(curMaskBuffer, curMaskBuffer + maskPadCount, maskFalse);
248     }
249     mLeftPadLength = 0; // Reset pad length
250   } else if (mRightPadLength > 0) {
251     // Mask the padded rows
252     const auto startIdx = tokenBatchSize - mRightPadLength;
253     for (size_t inTokIdx = startIdx; inTokIdx < tokenBatchSize; inTokIdx++) {
254       auto curMaskBuffer = maskBuffer + inTokIdx * rowSize;
255       std::fill(curMaskBuffer, curMaskBuffer + maskLength, maskFalse);
256     }
257     mRightPadLength = 0; // Reset pad length
258   }
259   return true; // Mask is modified for padding
260 }
261 
262 } // namespace llm_helper
263 } // namespace example
264