/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#include "llm_helper/include/rotary_embedding.h"
#include "llm_helper/include/llm_types.h"

#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <type_traits>

namespace example {
namespace llm_helper {

RotaryEmbeddingMasterLut::RotaryEmbeddingMasterLut(
    const LLMType rotEmbType,
    const size_t length,
    const size_t headDim,
    const float rotBase,
    const float ntkScale)
    : kType(rotEmbType),
      kTypeSize(getLLMTypeSize(kType)),
      kLength(length),
      kHeadDim(headDim),
      kRotBase(rotBase),
      kNtkScale(ntkScale) {
  // Shape: (length, 2*headDim), where 2 is sin & cos
  mMasterLut = std::make_unique<char[]>(kLength * 2 * kHeadDim * kTypeSize);
}
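
// Typical usage (illustrative sketch; the dimensions and values below are made
// up, not taken from this file):
//   RotaryEmbeddingMasterLut rotEmbLut(
//       LLMType::FP16, /*length*/ 4096, /*headDim*/ 128,
//       /*rotBase*/ 10000.0f, /*ntkScale*/ 1.0f);
//   rotEmbLut.load("", "");  // empty paths -> generate() the table in memory
//   rotEmbLut.setEmbed(rotEmbedBuffers, tokenIndex, tokenBatchSize, 0, 0);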

void RotaryEmbeddingMasterLut::load(
    const std::string& sinMasterPath,
    const std::string& cosMasterPath) {
  if (sinMasterPath.size() == 0 && cosMasterPath.size() == 0) {
    generate();
    return;
  }

  ET_LOG(
      Debug,
      "Begin loading rotary embedding lookup table from provided paths.");

  std::ifstream fileCos(cosMasterPath, std::ios::binary);
  std::ifstream fileSin(sinMasterPath, std::ios::binary);

  // File paths checking
  if (!fileCos) {
    ET_LOG(
        Info,
        "Warn: Rotary embedding lookup table file not found: %s. "
        "Will generate rotary embedding lookup table instead.",
        cosMasterPath.c_str());
    generate();
    return;
  }
  if (!fileSin) {
    ET_LOG(
        Info,
        "Warn: Rotary embedding lookup table file not found: %s. "
        "Will generate rotary embedding lookup table instead.",
        sinMasterPath.c_str());
    generate();
    return;
  }

  const auto rows = kLength;
  const auto rowSize = 2 * kHeadDim * kTypeSize; // x2 for sin & cos
  const size_t cosOffset = 0;
  const size_t sinOffset =
      rowSize / 2; // Halfway in row because each row is [<cos><sin>]
  const auto readSize = kHeadDim * kTypeSize;

  // Read lookup table files
  for (size_t i = 0; i < rows; ++i) {
    // Read cos then sin
    fileCos.read(mMasterLut.get() + i * rowSize + cosOffset, readSize);
    fileSin.read(mMasterLut.get() + i * rowSize + sinOffset, readSize);
  }
  mIsReady = true;
}

// For float and __fp16
template <typename RotEmbType>
void RotaryEmbeddingMasterLut::generate() {
  static_assert(
      std::is_same<RotEmbType, float>() || std::is_same<RotEmbType, __fp16>(),
      "Only fp32 and fp16 are supported by this template; int16 uses an "
      "explicit specialization.");
  ET_LOG(Debug, "Generating floating-point rotary embedding lookup table");

  const auto rowSize = kHeadDim * 2; // x2 for sin & cos
  const size_t rotDim = kHeadDim;
  const size_t rotDimHalf = rotDim / 2;

  const float rotDimFp = static_cast<float>(kHeadDim);
  const float base = (kNtkScale == 1.0f)
      ? kRotBase
      : std::powf(kRotBase * kNtkScale, rotDimFp / (rotDimFp - 2.0f));
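  // In other words, with kNtkScale != 1 the effective base becomes
  // (kRotBase * kNtkScale)^(headDim / (headDim - 2)). A larger base lowers the
  // per-dimension frequencies (longer wavelengths), which is how NTK-style
  // scaling stretches the rotary embedding to cover longer contexts.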

  for (size_t pos = 0; pos < kLength; pos++) { // row in lut
    for (size_t dim = 0; dim < rotDimHalf; dim++) {
      const float freq =
          float(pos) / std::powf(base, float(dim * 2) / rotDimFp);
      const RotEmbType embCos = static_cast<RotEmbType>(std::cos(freq));
      const RotEmbType embSin = static_cast<RotEmbType>(std::sin(freq));

      const auto& row = pos;
      const auto& col = dim; // At most kHeadDim / 2
      auto masterLutCurPtr =
          reinterpret_cast<RotEmbType*>(mMasterLut.get()) + row * rowSize + col;

      // Concat Cos then Sin, and duplicate each
      // Each row looks like this:
      //   [<--cos--><--cos--><--sin--><--sin-->]
      //   |         |        |        |
      //   0         rotDimHalf        |
      //                      rotDim   |
      //                               rotDim + rotDimHalf
      masterLutCurPtr[0] = embCos;
      masterLutCurPtr[rotDimHalf] = embCos;
      masterLutCurPtr[rotDim] = embSin;
      masterLutCurPtr[rotDim + rotDimHalf] = embSin;
    }
  }
  mIsReady = true;
}

// NOTE: The difference between this and the rotary embedding master lut
// generated by the Python script is the tie-breaking rule used when
// quantizing to INT16. Python's Numpy library uses round-half-to-even
// (banker's rounding), whereas the C++ code below uses std::round, which
// rounds halfway cases away from zero.
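// For example, a value whose scaled magnitude is exactly 2.5 quantizes to 2
// under NumPy's banker's rounding but to 3 under std::round, so the two tables
// can differ by one quantization step on such ties.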
template <>
void RotaryEmbeddingMasterLut::generate<int16_t>() {
  ET_LOG(Debug, "Generating int16 rotary embedding lookup table");

  const auto rowSize = kHeadDim * 2; // x2 for sin & cos
  const size_t rotDim = kHeadDim;
  const size_t rotDimHalf = rotDim / 2;

  const float rotDimFp = static_cast<float>(kHeadDim);
  const float base = (kNtkScale == 1.0f)
      ? kRotBase
      : std::powf(kRotBase * kNtkScale, rotDimFp / (rotDimFp - 2.0f));

  // Minmax=(-1,1), so qscale = 1/32767
  const float qscale = 0.000030518509447574615;

  auto quantFP32ToINT16 = [&](const float fpval) -> int16_t {
    const int qmin = -32768; // -2^(outBitwidth-1)
    const int qmax = +32767; // 2^(outBitwidth-1)-1
    const int quantized = std::round(fpval / qscale);
    const int clamped = std::max(qmin, std::min(quantized, qmax));
    return clamped;
  };
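
  // Sanity check of the mapping above: cos/sin values lie in [-1, 1], and
  // quantFP32ToINT16(1.0f) == 32767 while quantFP32ToINT16(-1.0f) == -32767,
  // so the whole table fits the int16 range and the clamp only guards against
  // rounding overshoot.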

  for (size_t pos = 0; pos < kLength; pos++) { // row in lut
    for (size_t dim = 0; dim < rotDimHalf; dim++) {
      const float freq =
          float(pos) / std::powf(base, float(dim * 2) / rotDimFp);
      const int16_t embCos = quantFP32ToINT16(std::cos(freq));
      const int16_t embSin = quantFP32ToINT16(std::sin(freq));

      const auto& row = pos;
      const auto& col = dim; // At most kHeadDim / 2
      auto masterLutCurPtr =
          reinterpret_cast<int16_t*>(mMasterLut.get()) + row * rowSize + col;

      // Concat Cos then Sin, and duplicate each
      // Each row looks like this:
      //   [<--cos--><--cos--><--sin--><--sin-->]
      //   |         |        |        |
      //   0         rotDimHalf        |
      //                      rotDim   |
      //                               rotDim + rotDimHalf
      masterLutCurPtr[0] = embCos;
      masterLutCurPtr[rotDimHalf] = embCos;
      masterLutCurPtr[rotDim] = embSin;
      masterLutCurPtr[rotDim + rotDimHalf] = embSin;
    }
  }
  mIsReady = true;
}

void RotaryEmbeddingMasterLut::generate() {
  switch (kType) {
    case LLMType::INT16:
      generate<int16_t>();
      return;
    case LLMType::FP16:
      generate<__fp16>();
      return;
    case LLMType::FP32:
      generate<float>();
      return;
    default:
      break;
  }
  ET_LOG(
      Fatal,
      "Rotary embedding generator not implemented for %s",
      getLLMTypeName(kType));
}

// RotaryEmbeddingMasterLut supports 1 or 2 rotary embedding inputs.
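// With a single buffer, cos and sin are packed together as
// [2, tokenBatchSize, headDim]; with two buffers, cos and sin are written into
// separate buffers of shape [tokenBatchSize, headDim] each.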
void RotaryEmbeddingMasterLut::setEmbed(
    std::vector<void*> rotEmbedBuffers,
    const size_t tokenIndex,
    const size_t tokenBatchSize,
    const size_t leftPadLength,
    const size_t rightPadLength) const {
  const auto numRotEmbInputs = rotEmbedBuffers.size();
  switch (numRotEmbInputs) {
    case 1: {
      const auto rotEmbInput = rotEmbedBuffers[0];
      setEmbed(
          rotEmbInput,
          tokenIndex,
          tokenBatchSize,
          leftPadLength,
          rightPadLength);
      break;
    }
    case 2: {
      const auto rotEmbCosInput = rotEmbedBuffers[0];
      const auto rotEmbSinInput = rotEmbedBuffers[1];
      setEmbed(
          rotEmbCosInput,
          rotEmbSinInput,
          tokenIndex,
          tokenBatchSize,
          leftPadLength,
          rightPadLength);
      break;
    }
    default:
      ET_LOG(
          Fatal,
          "RotaryEmbeddingMasterLut: Unsupported number of rotary embedding inputs (%zu).",
          numRotEmbInputs);
  }
}

void RotaryEmbeddingMasterLut::setEmbed(
    void* rotEmbedBuffer,
    const size_t tokenIndex,
    const size_t tokenBatchSize,
    const size_t leftPadLength,
    const size_t rightPadLength) const {
  // The master LUT must have been loaded or generated before use.
  if (!mIsReady) {
    ET_LOG(
        Error,
        "Attempting to use the rotary embedding lookup table before it has been initialized.");
    return;
  }
  const auto requestedMaxIndex = tokenIndex + tokenBatchSize - 1;
  const auto availableLength = getRotEmbedLength();
  if (requestedMaxIndex >= availableLength) {
    ET_LOG(
        Fatal,
        "Requested rotary embedding index (%zu) exceeds the maximum available (%zu) "
        "in the master lookup table. Please ensure that your maxTokenLength option "
        "is set correctly.",
        requestedMaxIndex,
        availableLength);
  }
  // The model takes in the rot emb as [2, tokenBatchSize, kHeadDim],
  // but the master lut stores it as [tokenIdx, 2, kHeadDim].
  const auto rowSizeBytes = 2 * kHeadDim * kTypeSize; // cos and sin
  const auto rowSizeBytesHalf = rowSizeBytes / 2; // one of cos or sin only
  const auto cosOffset = 0;
  const auto sinOffset = rowSizeBytesHalf;
  const auto copySize = rowSizeBytesHalf;

  auto curRotEmbedBuffer = reinterpret_cast<char*>(rotEmbedBuffer);
  const auto masterLutStart = mMasterLut.get() + tokenIndex * rowSizeBytes;

  ET_DCHECK(tokenBatchSize >= leftPadLength + rightPadLength);
  const size_t numValidInputToken =
      tokenBatchSize - leftPadLength - rightPadLength;

  const auto leftPadSize = copySize * leftPadLength;
  const auto rightPadSize = copySize * rightPadLength;

  // Skip left-padding
  curRotEmbedBuffer += leftPadSize;

  // cos
  for (size_t i = 0; i < numValidInputToken; i++) {
    std::memcpy(
        curRotEmbedBuffer,
        masterLutStart + i * rowSizeBytes + cosOffset,
        copySize);
    curRotEmbedBuffer += copySize;
  }

  // Right pad for 'cos', then skip the left pad of 'sin'.
  std::memset(curRotEmbedBuffer, 0, rightPadSize);
  curRotEmbedBuffer += leftPadSize + rightPadSize;

  // sin
  for (size_t i = 0; i < numValidInputToken; i++) {
    std::memcpy(
        curRotEmbedBuffer,
        masterLutStart + i * rowSizeBytes + sinOffset,
        copySize);
    curRotEmbedBuffer += copySize;
  }

  // Right pad for 'sin'
  std::memset(curRotEmbedBuffer, 0, rightPadSize);
}

void RotaryEmbeddingMasterLut::setEmbed(
    void* rotEmbedCosBuffer,
    void* rotEmbedSinBuffer,
    const size_t tokenIndex,
    const size_t tokenBatchSize,
    const size_t leftPadLength,
    const size_t rightPadLength) const {
  // The master LUT must have been loaded or generated before use.
  if (!mIsReady) {
    ET_LOG(
        Error,
        "Attempting to use the rotary embedding lookup table before it has been initialized.");
    return;
  }
  const auto requestedMaxIndex = tokenIndex + tokenBatchSize - 1;
  const auto availableLength = getRotEmbedLength();
  if (requestedMaxIndex >= availableLength) {
    ET_LOG(
        Fatal,
        "Requested rotary embedding index (%zu) exceeds the maximum available (%zu) "
        "in the master lookup table. Please ensure that your maxTokenLength option "
        "is set correctly.",
        requestedMaxIndex,
        availableLength);
  }
  // The model takes the rot emb as two separate [tokenBatchSize, kHeadDim]
  // buffers (cos and sin), but the master lut stores them as
  // [tokenIdx, 2, kHeadDim].
  const auto rowSizeBytes = 2 * kHeadDim * kTypeSize; // cos and sin
  const auto rowSizeBytesHalf = rowSizeBytes / 2; // one of cos or sin only
  const auto cosOffset = 0;
  const auto sinOffset = rowSizeBytesHalf;
  const auto copySize = rowSizeBytesHalf;

  const auto masterLutStart = mMasterLut.get() + tokenIndex * rowSizeBytes;

  auto curRotEmbedCosBuffer = reinterpret_cast<char*>(rotEmbedCosBuffer);
  auto curRotEmbedSinBuffer = reinterpret_cast<char*>(rotEmbedSinBuffer);

  ET_DCHECK(tokenBatchSize >= leftPadLength + rightPadLength);
  const size_t numValidInputToken =
      tokenBatchSize - leftPadLength - rightPadLength;

  const auto leftPadSize = copySize * leftPadLength;
  const auto rightPadSize = copySize * rightPadLength;

  // Skip left-padding
  curRotEmbedCosBuffer += leftPadSize;
  curRotEmbedSinBuffer += leftPadSize;

  for (size_t i = 0; i < numValidInputToken; i++) {
    std::memcpy(
        curRotEmbedCosBuffer,
        masterLutStart + i * rowSizeBytes + cosOffset,
        copySize);
    std::memcpy(
        curRotEmbedSinBuffer,
        masterLutStart + i * rowSizeBytes + sinOffset,
        copySize);
    curRotEmbedCosBuffer += copySize;
    curRotEmbedSinBuffer += copySize;
  }

  // Right pad for 'cos' and 'sin'
  std::memset(curRotEmbedCosBuffer, 0, rightPadSize);
  std::memset(curRotEmbedSinBuffer, 0, rightPadSize);
}

size_t RotaryEmbeddingMasterLut::getRotEmbedSizeBytes(
    const size_t tokenBatchSize) const {
  return 2 * tokenBatchSize * kHeadDim * kTypeSize;
}
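// For example (illustrative numbers only): with FP16 embeddings (2 bytes per
// element), headDim = 128 and tokenBatchSize = 32, the buffer size is
// 2 * 32 * 128 * 2 = 16384 bytes.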

// The rotary embedding length determines the maximum token length (number of
// positions) the model can handle.
size_t RotaryEmbeddingMasterLut::getRotEmbedLength() const {
  return kLength;
}

} // namespace llm_helper
} // namespace example