1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
#include "llm_types.h"

#include <cstddef>
#include <memory>
#include <string>
#include <vector>
15 
16 namespace example {
17 namespace llm_helper {
18 
19 class RotaryEmbeddingMasterLut {
20  public:
21   RotaryEmbeddingMasterLut(
22       const LLMType rotEmbType,
23       const size_t length,
24       const size_t headDim,
25       const float rotBase = 10000.0,
26       const float ntkScale = 1.0);
27 
~RotaryEmbeddingMasterLut()28   virtual ~RotaryEmbeddingMasterLut() {}
29 
30   void load(const std::string& sinMasterPath, const std::string& cosMasterPath);
31 
32   void generate();
33 
34   template <typename RotEmbType>
35   void generate();
36 
37   virtual void setEmbed(
38       std::vector<void*> rotEmbedBuffers,
39       const size_t tokenIndex,
40       const size_t tokenBatchSize = 1,
41       const size_t leftPadLength = 0,
42       const size_t rightPadLength = 0) const;
43 
44   // Single rot emb input with combined cos & sin
45   void setEmbed(
46       void* rotEmbedBuffer,
47       const size_t tokenIndex,
48       const size_t tokenBatchSize = 1,
49       const size_t leftPadLength = 0,
50       const size_t rightPadLength = 0) const;
51 
52   // Two rot emb inputs for separated cos & sin
53   void setEmbed(
54       void* rotEmbedCosBuffer,
55       void* rotEmbedSinBuffer,
56       const size_t tokenIndex,
57       const size_t tokenBatchSize = 1,
58       const size_t leftPadLength = 0,
59       const size_t rightPadLength = 0) const;
60 
61   size_t getRotEmbedSizeBytes(const size_t tokenBatchSize = 1) const;
62 
63   // The rotary embedding length is and determines the largest token size the
64   // model can handle
65   size_t getRotEmbedLength() const;
66 
67  private:
68   std::unique_ptr<char[]> mMasterLut; // byte flatten array
69   bool mIsReady = false;
70 
71   const LLMType kType;
72   const size_t kTypeSize; // in bytes
73   const size_t kLength;
74   const size_t kHeadDim;
75   const float kRotBase = 10000.0;
76   const float kNtkScale;
77 };
78 
79 } // namespace llm_helper
80 } // namespace example
81