1 /* 2 * Copyright (c) 2024 MediaTek Inc. 3 * 4 * Licensed under the BSD License (the "License"); you may not use this file 5 * except in compliance with the License. See the license file in the root 6 * directory of this source tree for more details. 7 */ 8 9 #pragma once 10 11 #include "llm_types.h" 12 13 #include <string> 14 #include <vector> 15 16 namespace example { 17 namespace llm_helper { 18 19 class RotaryEmbeddingMasterLut { 20 public: 21 RotaryEmbeddingMasterLut( 22 const LLMType rotEmbType, 23 const size_t length, 24 const size_t headDim, 25 const float rotBase = 10000.0, 26 const float ntkScale = 1.0); 27 ~RotaryEmbeddingMasterLut()28 virtual ~RotaryEmbeddingMasterLut() {} 29 30 void load(const std::string& sinMasterPath, const std::string& cosMasterPath); 31 32 void generate(); 33 34 template <typename RotEmbType> 35 void generate(); 36 37 virtual void setEmbed( 38 std::vector<void*> rotEmbedBuffers, 39 const size_t tokenIndex, 40 const size_t tokenBatchSize = 1, 41 const size_t leftPadLength = 0, 42 const size_t rightPadLength = 0) const; 43 44 // Single rot emb input with combined cos & sin 45 void setEmbed( 46 void* rotEmbedBuffer, 47 const size_t tokenIndex, 48 const size_t tokenBatchSize = 1, 49 const size_t leftPadLength = 0, 50 const size_t rightPadLength = 0) const; 51 52 // Two rot emb inputs for separated cos & sin 53 void setEmbed( 54 void* rotEmbedCosBuffer, 55 void* rotEmbedSinBuffer, 56 const size_t tokenIndex, 57 const size_t tokenBatchSize = 1, 58 const size_t leftPadLength = 0, 59 const size_t rightPadLength = 0) const; 60 61 size_t getRotEmbedSizeBytes(const size_t tokenBatchSize = 1) const; 62 63 // The rotary embedding length is and determines the largest token size the 64 // model can handle 65 size_t getRotEmbedLength() const; 66 67 private: 68 std::unique_ptr<char[]> mMasterLut; // byte flatten array 69 bool mIsReady = false; 70 71 const LLMType kType; 72 const size_t kTypeSize; // in bytes 73 const size_t kLength; 74 const size_t kHeadDim; 75 const float kRotBase = 10000.0; 76 const float kNtkScale; 77 }; 78 79 } // namespace llm_helper 80 } // namespace example 81