/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#include "llm_helper/include/mask_builder.h"

#include <algorithm>
#include <cstdint>

#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

namespace example {
namespace llm_helper {

// Define mask values for different types
template <typename T>
struct MaskVal;

#define __DECL_MASK__(TYPE, TRUE_VAL, FALSE_VAL) \
  template <>                                    \
  struct MaskVal<TYPE> {                         \
    static constexpr TYPE kTrue = TRUE_VAL;      \
    static constexpr TYPE kFalse = FALSE_VAL;    \
  };

__DECL_MASK__(bool, true, false)
__DECL_MASK__(int16_t, 0, -32768)
__DECL_MASK__(__fp16, 0, -100)
__DECL_MASK__(float, 0, -100)
#undef __DECL_MASK__
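
// Note: kTrue marks a position that may be attended to, while kFalse blocks
// it. For the numeric mask types, the choice of 0 vs. a large negative value
// (-32768, -100) is consistent with the mask being applied as an additive
// attention bias before softmax, rather than as a boolean flag.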

MaskBuilder::MaskBuilder(
    void* maskBuffer,
    const size_t maskSizeBytes,
    const LLMType maskType,
    const size_t cacheLength)
    : mMaskBuffer(maskBuffer),
      mMaskSizeBytes(maskSizeBytes),
      kMaskType(maskType),
      kMaskTypeSize(getLLMTypeSize(maskType)),
      kCacheLength(cacheLength) {}

MaskBuilder::~MaskBuilder() {}

void MaskBuilder::updateMaskSize(const size_t sizeBytes) {
  mMaskSizeBytes = sizeBytes;
}

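// Invalidate the current mask so that the next updateMask() call rebuilds it
// from scratch instead of patching it incrementally.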
void MaskBuilder::markMaskDirty() {
  mIsMaskUpdatable = false;
}

template <typename MaskType>
void MaskBuilder::buildMask(
    const size_t tokenBatchSize,
    const size_t numSeenToken) {
  constexpr auto maskTrue = MaskVal<MaskType>::kTrue;
  constexpr auto maskFalse = MaskVal<MaskType>::kFalse;
  const size_t maskLength = kCacheLength + tokenBatchSize;

  // The mask is a combination (concat) of input cache mask and attention mask
  const size_t startTrueIdx =
      kCacheLength - std::min(kCacheLength, numSeenToken);

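  // Number of mask elements per row in the model's mask buffer. This may
  // exceed maskLength if the buffer was padded to fit HW requirements.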
  const size_t rowSize = mMaskSizeBytes / tokenBatchSize / kMaskTypeSize;

  const size_t expectedMaskSizeBytes =
      tokenBatchSize * maskLength * kMaskTypeSize;
  // Use '<' instead of '!=' because mMaskSizeBytes may be padded by compiler
  // to fit HW
  if (mMaskSizeBytes < expectedMaskSizeBytes) {
    ET_LOG(
        Info,
        "Warn: Model input mask size (%zu) < mask size to be built (%zu). "
        "Please ensure your model options are set correctly.",
        mMaskSizeBytes,
        expectedMaskSizeBytes);
  }

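  // Each row (one per input token) is laid out as below. The cache mask is
  // right-aligned, so the not-yet-filled cache slots at the front stay False
  // until enough tokens have been seen:
  //
  //   [ False ... | True x min(numSeenToken, kCacheLength) ]  <- cache mask
  //   [ True x (inTokIdx + 1) | False ... ]                   <- attention mask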
  // There are tokenBatchSize number of rows
  for (size_t inTokIdx = 0; inTokIdx < tokenBatchSize; inTokIdx++) {
    const auto& rowIdx = inTokIdx; // For clarity
    auto curMaskBuffer =
        reinterpret_cast<MaskType*>(mMaskBuffer) + rowIdx * rowSize;
    size_t i = 0; // Buffer write index

    // Set the (rectangle) input cache mask
    while (i < startTrueIdx)
      curMaskBuffer[i++] = maskFalse;
    while (i < kCacheLength)
      curMaskBuffer[i++] = maskTrue;

    // Set the (triangle) attention mask
    const size_t attnTrueCount = inTokIdx + 1;
    for (size_t counter = 0; counter < attnTrueCount; counter++) {
      curMaskBuffer[i++] = maskTrue;
    }
    // Fill the remaining with False
    while (i < maskLength)
      curMaskBuffer[i++] = maskFalse;
  }

  // Modify mask for padding if needed. Mask is not updatable if modified for
  // padding.
  mIsMaskUpdatable = !adjustMaskForPadding<MaskType>(tokenBatchSize);
}

template <typename MaskType>
void MaskBuilder::updateMask(
    const size_t tokenBatchSize,
    const size_t numSeenToken,
    const size_t length) {
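  // Incremental update: if the previously built mask is still valid, only the
  // cache portion of each row needs to change; the triangular attention part
  // is left untouched. `length` is expected to be the number of newly seen
  // tokens, so at most min(length, numSeenToken) cache slots are unmasked.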
  constexpr auto maskTrue = MaskVal<MaskType>::kTrue;

  if (!mIsMaskUpdatable) {
    buildMask<MaskType>(tokenBatchSize, numSeenToken);
    return;
  }

  // Only set True for seen token
  const size_t trueCount = std::min(length, numSeenToken);
  if (!trueCount) {
    // Modify mask for padding if needed. Mask is not updatable if modified for
    // padding.
    mIsMaskUpdatable = !adjustMaskForPadding<MaskType>(tokenBatchSize);
    return;
  }

  // The mask is a combination (concat) of input cache mask and attention mask
  auto maskBuffer = reinterpret_cast<MaskType*>(mMaskBuffer);

  const size_t rowSize = mMaskSizeBytes / tokenBatchSize / kMaskTypeSize;

  // Only modify the left rectangle part
  const size_t startTrueOffset =
      kCacheLength - std::min(kCacheLength, numSeenToken);

  for (size_t inTokIdx = 0; inTokIdx < tokenBatchSize; inTokIdx++) {
    const auto& rowIdx = inTokIdx; // For clarity
    auto curMaskBuffer = maskBuffer + rowIdx * rowSize + startTrueOffset;
    std::fill(curMaskBuffer, curMaskBuffer + trueCount, maskTrue);
  }
  // Modify mask for padding if needed. Mask is not updatable if modified for
  // padding.
  mIsMaskUpdatable = !adjustMaskForPadding<MaskType>(tokenBatchSize);
}

void MaskBuilder::buildMask(
    const size_t tokenBatchSize,
    const size_t numSeenToken) {
  switch (kMaskType) {
    case LLMType::INT16:
      buildMask<int16_t>(tokenBatchSize, numSeenToken);
      return;
    case LLMType::FP16:
      buildMask<__fp16>(tokenBatchSize, numSeenToken);
      return;
    case LLMType::FP32:
      buildMask<float>(tokenBatchSize, numSeenToken);
      return;
    default:
      break;
  }
  ET_LOG(
      Fatal,
      "Attempting to build mask with unsupported type %s. "
      "Supported types are INT16, FP16, FP32.",
      getLLMTypeName(kMaskType));
}

void MaskBuilder::updateMask(
    const size_t tokenBatchSize,
    const size_t numSeenToken,
    const size_t length) {
  switch (kMaskType) {
    case LLMType::INT16:
      updateMask<int16_t>(tokenBatchSize, numSeenToken, length);
      return;
    case LLMType::FP16:
      updateMask<__fp16>(tokenBatchSize, numSeenToken, length);
      return;
    case LLMType::FP32:
      updateMask<float>(tokenBatchSize, numSeenToken, length);
      return;
    default:
      break;
  }
  ET_LOG(
      Fatal,
      "Attempting to update mask with unsupported type %s. "
      "Supported types are INT16, FP16, FP32.",
      getLLMTypeName(kMaskType));
}

void MaskBuilder::notifyLeftPadding(const size_t padLength) {
  ET_CHECK_MSG(
      mRightPadLength == 0,
      "Attempting to set left pad after right pad has been set.");
  if (mLeftPadLength > 0) {
    ET_LOG(
        Info,
        "Warn: Calling notifyLeftPadding() multiple times before building/updating mask.");
  }
  mLeftPadLength = padLength;
}

void MaskBuilder::notifyRightPadding(const size_t padLength) {
  ET_CHECK_MSG(
      mLeftPadLength == 0,
      "Attempting to set right pad after left pad has been set.");
  if (mRightPadLength > 0) {
    ET_LOG(
        Info,
        "Warn: Calling notifyRightPadding() multiple times before building/updating mask.");
  }
  mRightPadLength = padLength;
}

template <typename MaskType>
bool MaskBuilder::adjustMaskForPadding(const size_t tokenBatchSize) {
  if (mLeftPadLength + mRightPadLength == 0) {
    return false; // No need to modify mask since no padding
  }
  ET_DCHECK_MSG(
      mLeftPadLength == 0 || mRightPadLength == 0,
      "Only allow setting either left or right pad");
  constexpr auto maskFalse = MaskVal<MaskType>::kFalse;
  const size_t maskLength = kCacheLength + tokenBatchSize;

  // The mask is a combination (concat) of input cache mask and attention mask
  auto maskBuffer = reinterpret_cast<MaskType*>(mMaskBuffer);

  const size_t rowSize = mMaskSizeBytes / tokenBatchSize / kMaskTypeSize;

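  // Padded positions must never be attended to. For left padding, the first
  // mLeftPadLength rows are fully masked and the corresponding leading columns
  // of the attention (triangle) region are masked in the remaining rows. For
  // right padding, the last mRightPadLength rows are fully masked.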
  if (mLeftPadLength > 0) {
    // Mask the padded rows
    for (size_t inTokIdx = 0; inTokIdx < mLeftPadLength; inTokIdx++) {
      auto curMaskBuffer = maskBuffer + inTokIdx * rowSize;
      std::fill(curMaskBuffer, curMaskBuffer + maskLength, maskFalse);
    }
    // Mask the padded attention region
    for (size_t inTokIdx = mLeftPadLength; inTokIdx < tokenBatchSize;
         inTokIdx++) {
      auto curMaskBuffer = maskBuffer + inTokIdx * rowSize + kCacheLength;
      // Anything from inTokIdx + 1 onwards is already False, so it can be
      // skipped.
      const size_t maskPadCount = std::min(mLeftPadLength, inTokIdx + 1);
      std::fill(curMaskBuffer, curMaskBuffer + maskPadCount, maskFalse);
    }
    mLeftPadLength = 0; // Reset pad length
  } else if (mRightPadLength > 0) {
    // Mask the padded rows
    const auto startIdx = tokenBatchSize - mRightPadLength;
    for (size_t inTokIdx = startIdx; inTokIdx < tokenBatchSize; inTokIdx++) {
      auto curMaskBuffer = maskBuffer + inTokIdx * rowSize;
      std::fill(curMaskBuffer, curMaskBuffer + maskLength, maskFalse);
    }
    mRightPadLength = 0; // Reset pad length
  }
  return true; // Mask is modified for padding
}
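
// A rough usage sketch (hypothetical variable names, shown only to illustrate
// the intended call sequence based on the signatures in this file):
//
//   MaskBuilder maskBuilder(maskBuffer, maskSizeBytes, LLMType::INT16, cacheLength);
//   maskBuilder.buildMask(tokenBatchSize, numSeenToken);           // build from scratch
//   ...
//   maskBuilder.updateMask(tokenBatchSize, numSeenToken, length);  // incremental update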

} // namespace llm_helper
} // namespace example