1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
#include "llm_helper/include/rotary_embedding.h"
#include "llm_helper/include/llm_types.h"

#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
18 
19 namespace example {
20 namespace llm_helper {
21 
RotaryEmbeddingMasterLut(const LLMType rotEmbType,const size_t length,const size_t headDim,const float rotBase,const float ntkScale)22 RotaryEmbeddingMasterLut::RotaryEmbeddingMasterLut(
23     const LLMType rotEmbType,
24     const size_t length,
25     const size_t headDim,
26     const float rotBase,
27     const float ntkScale)
28     : kType(rotEmbType),
29       kTypeSize(getLLMTypeSize(kType)),
30       kLength(length),
31       kHeadDim(headDim),
32       kRotBase(rotBase),
33       kNtkScale(ntkScale) {
34   // Shape: (length, 2*headDim), where 2 is sin & cos
35   mMasterLut = std::make_unique<char[]>(kLength * 2 * kHeadDim * kTypeSize);
36 }
37 
load(const std::string & sinMasterPath,const std::string & cosMasterPath)38 void RotaryEmbeddingMasterLut::load(
39     const std::string& sinMasterPath,
40     const std::string& cosMasterPath) {
41   if (sinMasterPath.size() == 0 && cosMasterPath.size() == 0) {
42     generate();
43     return;
44   }
45 
46   ET_LOG(
47       Debug,
48       "Begin loading rotary embedding lookup table from provided paths.");
49 
50   std::ifstream fileCos(cosMasterPath, std::ios::binary);
51   std::ifstream fileSin(sinMasterPath, std::ios::binary);
52 
53   // File paths checking
54   if (!fileCos) {
55     ET_LOG(
56         Info,
57         "Warn: Rotary embedding lookup table file not found: %s. "
58         "Will generate rotary embedding lookup table instead.",
59         cosMasterPath.c_str());
60     generate();
61     return;
62   }
63   if (!fileSin) {
64     ET_LOG(
65         Info,
66         "Warn: Rotary embedding lookup table file not found: %s. "
67         "Will generate rotary embedding lookup table instead.",
68         sinMasterPath.c_str());
69     generate();
70     return;
71   }
72 
73   const auto rows = kLength;
74   const auto rowSize = 2 * kHeadDim * kTypeSize; // x2 for sin & cos
75   const size_t cosOffset = 0;
76   const size_t sinOffset =
77       rowSize / 2; // Halfway in row because each row is [<cos><sin>]
78   const auto readSize = kHeadDim * kTypeSize;
79 
80   // Read lookup table files
81   for (size_t i = 0; i < rows; ++i) {
82     // Read cos then sin
83     fileCos.read(mMasterLut.get() + i * rowSize + cosOffset, readSize);
84     fileSin.read(mMasterLut.get() + i * rowSize + sinOffset, readSize);
85   }
86   mIsReady = true;
87 }
88 
89 // For float and __fp16
90 template <typename RotEmbType>
generate()91 void RotaryEmbeddingMasterLut::generate() {
92   static_assert(
93       std::is_same<RotEmbType, float>() || std::is_same<RotEmbType, __fp16>(),
94       "Only int16/fp16/fp32 are supported for RotEmbType");
95   ET_LOG(Debug, "Generating floating rotary embedding lookup table");
96 
97   const auto rowSize = kHeadDim * 2; // x2 for sin & cos
98   const size_t rotDim = kHeadDim;
99   const size_t rotDimHalf = rotDim / 2;
100 
101   const float rotDimFp = static_cast<float>(kHeadDim);
102   const float base = (kNtkScale == 1.0f)
103       ? kRotBase
104       : std::powf(kRotBase * kNtkScale, rotDimFp / (rotDimFp - 2.0f));
105 
106   for (int pos = 0; pos < kLength; pos++) { // row in lut
107     for (int dim = 0; dim < rotDimHalf; dim++) {
108       const float freq =
109           float(pos) / std::powf(base, float(dim * 2) / rotDimFp);
110       const RotEmbType embCos = static_cast<RotEmbType>(std::cos(freq));
111       const RotEmbType embSin = static_cast<RotEmbType>(std::sin(freq));
112 
113       const auto& row = pos;
114       const auto& col = dim; // At most kHeadDim / 2
115       auto masterLutCurPtr =
116           reinterpret_cast<RotEmbType*>(mMasterLut.get()) + row * rowSize + col;
117 
118       // Concat Cos then Sin, and duplicate each
119       // Each row looks like this:
120       //   [<--cos--><--cos--><--sin--><--sin-->]
121       //    |        |        |        |
122       //    0    rotDimHalf   |        |
123       //                    rotDim     |
124       //                        rotDim + rotDimHalf
125       masterLutCurPtr[0] = embCos;
126       masterLutCurPtr[rotDimHalf] = embCos;
127       masterLutCurPtr[rotDim] = embSin;
128       masterLutCurPtr[rotDim + rotDimHalf] = embSin;
129     }
130   }
131   mIsReady = true;
132 }
133 
134 // NOTE: The difference between this and the Python script generated rotary
135 // embedding master lut is the rounding mechanism during quantization to INT16.
136 // Python's Numpy library uses round-to-even (banker's rounding) whereas the
137 // below C++ code uses round-to-nearest.
138 template <>
generate()139 void RotaryEmbeddingMasterLut::generate<int16_t>() {
140   ET_LOG(Debug, "Generating int16 rotary embedding lookup table");
141 
142   const auto rowSize = kHeadDim * 2; // x2 for sin & cos
143   const size_t rotDim = kHeadDim;
144   const size_t rotDimHalf = rotDim / 2;
145 
146   const float rotDimFp = static_cast<float>(kHeadDim);
147   const float base = (kNtkScale == 1.0f)
148       ? kRotBase
149       : std::powf(kRotBase * kNtkScale, rotDimFp / (rotDimFp - 2.0f));
150 
151   // Minmax=(-1,1), so qscale = 1/32767
152   const float qscale = 0.000030518509447574615;
153 
154   auto quantFP32ToINT16 = [&](const float fpval) -> int16_t {
155     const int qmin = -32768; // -2^(outBitwidth-1)
156     const int qmax = +32767; // 2^(outBitwidth-1)-1
157     const int quantized = std::round(fpval / qscale);
158     const int clamped = std::max(qmin, std::min(quantized, qmax));
159     return clamped;
160   };
161 
162   for (int pos = 0; pos < kLength; pos++) { // row in lut
163     for (int dim = 0; dim < rotDimHalf; dim++) {
164       const float freq =
165           float(pos) / std::powf(base, float(dim * 2) / rotDimFp);
166       const int16_t embCos = quantFP32ToINT16(std::cos(freq));
167       const int16_t embSin = quantFP32ToINT16(std::sin(freq));
168 
169       const auto& row = pos;
170       const auto& col = dim; // At most kHeadDim / 2
171       auto masterLutCurPtr =
172           reinterpret_cast<int16_t*>(mMasterLut.get()) + row * rowSize + col;
173 
174       // Concat Cos then Sin, and duplicate each
175       // Each row looks like this:
176       //   [<--cos--><--cos--><--sin--><--sin-->]
177       //    |        |        |        |
178       //    0    rotDimHalf   |        |
179       //                    rotDim     |
180       //                        rotDim + rotDimHalf
181       masterLutCurPtr[0] = embCos;
182       masterLutCurPtr[rotDimHalf] = embCos;
183       masterLutCurPtr[rotDim] = embSin;
184       masterLutCurPtr[rotDim + rotDimHalf] = embSin;
185     }
186   }
187   mIsReady = true;
188 }
189 
generate()190 void RotaryEmbeddingMasterLut::generate() {
191   switch (kType) {
192     case LLMType::INT16:
193       generate<int16_t>();
194       return;
195     case LLMType::FP16:
196       generate<__fp16>();
197       return;
198     case LLMType::FP32:
199       generate<float>();
200       return;
201     default:
202       break;
203   }
204   ET_LOG(
205       Fatal,
206       "Rotary embedding generator not implemented for %s",
207       getLLMTypeName(kType));
208 }
209 
210 // RotaryEmbeddingMasterLut supports 1 or 2 rotary embedding inputs
setEmbed(std::vector<void * > rotEmbedBuffers,const size_t tokenIndex,const size_t tokenBatchSize,const size_t leftPadLength,const size_t rightPadLength) const211 void RotaryEmbeddingMasterLut::setEmbed(
212     std::vector<void*> rotEmbedBuffers,
213     const size_t tokenIndex,
214     const size_t tokenBatchSize,
215     const size_t leftPadLength,
216     const size_t rightPadLength) const {
217   const auto numRotEmbInputs = rotEmbedBuffers.size();
218   switch (numRotEmbInputs) {
219     case 1: {
220       const auto rotEmbInput = rotEmbedBuffers[0];
221       setEmbed(
222           rotEmbInput,
223           tokenIndex,
224           tokenBatchSize,
225           leftPadLength,
226           rightPadLength);
227       break;
228     }
229     case 2: {
230       const auto rotEmbCosInput = rotEmbedBuffers[0];
231       const auto rotEmbSinInput = rotEmbedBuffers[1];
232       setEmbed(
233           rotEmbCosInput,
234           rotEmbSinInput,
235           tokenIndex,
236           tokenBatchSize,
237           leftPadLength,
238           rightPadLength);
239       break;
240     }
241     default:
242       ET_LOG(
243           Fatal,
244           "RotaryEmbeddingMasterLut: Unsupported number of rotary embedding inputs (%zu).",
245           numRotEmbInputs);
246   }
247 }
248 
setEmbed(void * rotEmbedBuffer,const size_t tokenIndex,const size_t tokenBatchSize,const size_t leftPadLength,const size_t rightPadLength) const249 void RotaryEmbeddingMasterLut::setEmbed(
250     void* rotEmbedBuffer,
251     const size_t tokenIndex,
252     const size_t tokenBatchSize,
253     const size_t leftPadLength,
254     const size_t rightPadLength) const {
255   // Generate Master Lut if not yet done
256   if (!mIsReady) {
257     ET_LOG(
258         Error,
259         "Attempting to use the rotary embedding lookup table before being initialized.");
260     return;
261   }
262   const auto requestedMaxIndex = tokenIndex + tokenBatchSize - 1;
263   const auto availableLength = getRotEmbedLength();
264   if (requestedMaxIndex >= availableLength) {
265     ET_LOG(
266         Fatal,
267         "Requested rotary embeddings (%zu) exceeds the max available (%zu) "
268         "in the master lookup table. Please ensure that your maxTokenLength option "
269         "is set correctly",
270         requestedMaxIndex,
271         availableLength);
272   }
273   // The model takes in the rot emb as [2, tokenBatchSize, kHeadDim],
274   // but the master lut stores in [tokenIdx, 2, kHeadDim].
275   const auto rowSizeBytes = 2 * kHeadDim * kTypeSize; // cos and sin
276   const auto rowSizeBytesHalf = rowSizeBytes / 2; // one of cos or sin only
277   const auto cosOffset = 0;
278   const auto sinOffset = rowSizeBytesHalf;
279   const auto copySize = rowSizeBytesHalf;
280 
281   auto curRotEmbedBuffer = reinterpret_cast<char*>(rotEmbedBuffer);
282   const auto masterLutStart = mMasterLut.get() + tokenIndex * rowSizeBytes;
283 
284   ET_DCHECK(tokenBatchSize >= leftPadLength + rightPadLength);
285   const size_t numValidInputToken =
286       tokenBatchSize - leftPadLength - rightPadLength;
287 
288   const auto leftPadSize = copySize * leftPadLength;
289   const auto rightPadSize = copySize * rightPadLength;
290 
291   // Skip left-padding
292   curRotEmbedBuffer += leftPadSize;
293 
294   // cos
295   for (size_t i = 0; i < numValidInputToken; i++) {
296     std::memcpy(
297         curRotEmbedBuffer,
298         masterLutStart + i * rowSizeBytes + cosOffset,
299         copySize);
300     curRotEmbedBuffer += copySize;
301   }
302 
303   // Right pad for 'cos', and left pad for 'sin'.
304   std::memset(curRotEmbedBuffer, 0, rightPadSize);
305   curRotEmbedBuffer += leftPadSize + rightPadSize;
306 
307   // sin
308   for (size_t i = 0; i < numValidInputToken; i++) {
309     std::memcpy(
310         curRotEmbedBuffer,
311         masterLutStart + i * rowSizeBytes + sinOffset,
312         copySize);
313     curRotEmbedBuffer += copySize;
314   }
315 
316   // Right pad for 'sin'
317   std::memset(curRotEmbedBuffer, 0, rightPadSize);
318 }
319 
setEmbed(void * rotEmbedCosBuffer,void * rotEmbedSinBuffer,const size_t tokenIndex,const size_t tokenBatchSize,const size_t leftPadLength,const size_t rightPadLength) const320 void RotaryEmbeddingMasterLut::setEmbed(
321     void* rotEmbedCosBuffer,
322     void* rotEmbedSinBuffer,
323     const size_t tokenIndex,
324     const size_t tokenBatchSize,
325     const size_t leftPadLength,
326     const size_t rightPadLength) const {
327   // Generate Master Lut if not yet done
328   if (!mIsReady) {
329     ET_LOG(
330         Error,
331         "Attempting to use the rotary embedding lookup table before being initialized.");
332     return;
333   }
334   const auto requestedMaxIndex = tokenIndex + tokenBatchSize - 1;
335   const auto availableLength = getRotEmbedLength();
336   if (requestedMaxIndex >= availableLength) {
337     ET_LOG(
338         Fatal,
339         "Requested rotary embeddings (%zu) exceeds the max available (%zu) "
340         "in the master lookup table. Please ensure that your maxTokenLength option "
341         "is set correctly",
342         requestedMaxIndex,
343         availableLength);
344   }
345   // The model takes in the rot emb as [2, tokenBatchSize, kHeadDim],
346   // but the master lut stores in [tokenIdx, 2, kHeadDim].
347   const auto rowSizeBytes = 2 * kHeadDim * kTypeSize; // cos and sin
348   const auto rowSizeBytesHalf = rowSizeBytes / 2; // one of cos or sin only
349   const auto cosOffset = 0;
350   const auto sinOffset = rowSizeBytesHalf;
351   const auto copySize = rowSizeBytesHalf;
352 
353   const auto masterLutStart = mMasterLut.get() + tokenIndex * rowSizeBytes;
354 
355   auto curRotEmbedCosBuffer = reinterpret_cast<char*>(rotEmbedCosBuffer);
356   auto curRotEmbedSinBuffer = reinterpret_cast<char*>(rotEmbedSinBuffer);
357 
358   ET_DCHECK(tokenBatchSize >= leftPadLength + rightPadLength);
359   const size_t numValidInputToken =
360       tokenBatchSize - leftPadLength - rightPadLength;
361 
362   const auto leftPadSize = copySize * leftPadLength;
363   const auto rightPadSize = copySize * rightPadLength;
364 
365   // Skip left-padding
366   curRotEmbedCosBuffer += leftPadSize;
367   curRotEmbedSinBuffer += leftPadSize;
368 
369   for (size_t i = 0; i < numValidInputToken; i++) {
370     std::memcpy(
371         curRotEmbedCosBuffer,
372         masterLutStart + i * rowSizeBytes + cosOffset,
373         copySize);
374     std::memcpy(
375         curRotEmbedSinBuffer,
376         masterLutStart + i * rowSizeBytes + sinOffset,
377         copySize);
378     curRotEmbedCosBuffer += copySize;
379     curRotEmbedSinBuffer += copySize;
380   }
381   std::memset(curRotEmbedCosBuffer, 0, rightPadSize);
382   std::memset(curRotEmbedSinBuffer, 0, rightPadSize);
383 }
384 
getRotEmbedSizeBytes(const size_t tokenBatchSize) const385 size_t RotaryEmbeddingMasterLut::getRotEmbedSizeBytes(
386     const size_t tokenBatchSize) const {
387   return 2 * tokenBatchSize * kHeadDim * kTypeSize;
388 }
389 
390 // The rotary embedding length is and determines the largest token size the
391 // model can handle
getRotEmbedLength() const392 size_t RotaryEmbeddingMasterLut::getRotEmbedLength() const {
393   return kLength;
394 }
395 
396 } // namespace llm_helper
397 } // namespace example
398