/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#pragma once

#include <memory>
#include <string>
#include <vector>

#include <executorch/runtime/platform/log.h>

#include "LlamaConfig.h"
#include "LlamaModelChunk.h"
#include "llm_helper/include/llm_types.h"

#include "llm_helper/include/rotary_embedding.h"
#include "llm_helper/include/token_embedding.h"

namespace example {

class LlamaRuntime {
 public:
  explicit LlamaRuntime() {}
  ~LlamaRuntime() {}

  // Loads the model chunks and the token/rotary embedding lookup tables
  // described by modelOptions and modelPaths.
  void Initialize(
      const LlamaModelOptions& modelOptions,
      const LlamaModelPaths& modelPaths);

  // Releases all loaded model chunks and associated resources.
  void Release();

  // Switches to the model variant compiled for the given token batch size.
  void SwapModel(const size_t batchSize);

  // Runs a forward pass over inputTokens and returns a pointer to the output
  // logits. When lastLogits is true, only the last token's logits are
  // returned.
  void* Run(
      const std::vector<uint64_t>& inputTokens,
      const bool lastLogits = true);

  // Resets the runtime state (e.g. the current token index) for a new
  // sequence.
  void Reset();

  size_t GetTokenBatchSize() const;

  size_t GetTokenIndex() const;

  const LlamaModelOptions& GetModelOptions() const;

 private:
  LlamaModelOptions mModelOptions;
  std::vector<std::unique_ptr<ModelChunk>> mLlamaModelChunks;
  std::unique_ptr<llm_helper::TokenEmbeddingLut> mTokenEmbLut;
  std::unique_ptr<llm_helper::RotaryEmbeddingMasterLut> mRotEmbMasterLut;
  size_t mTokenBatchSize = 1;
  size_t mTokenIndex = 0;
};

} // namespace example
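
// Example usage (a minimal sketch, not part of this header): how the fields
// of LlamaModelOptions / LlamaModelPaths are populated and how logits are
// consumed depend on LlamaConfig.h and the runner implementation, and
// SampleToken below is a hypothetical sampling helper, so treat this as
// illustrative only.
//
//   example::LlamaRuntime runtime;
//
//   example::LlamaModelOptions options;  // assumed to be filled in elsewhere
//   example::LlamaModelPaths paths;      // model and embedding file paths
//   runtime.Initialize(options, paths);
//
//   // Prompt phase: run the tokenized prompt and keep only the last
//   // token's logits.
//   const std::vector<uint64_t> promptTokens = /* tokenized prompt */;
//   const void* logits = runtime.Run(promptTokens, /*lastLogits=*/true);
//
//   // Decode phase: switch to the batch-size-1 model and feed one token
//   // at a time.
//   runtime.SwapModel(/*batchSize=*/1);
//   const uint64_t nextToken = SampleToken(logits);  // hypothetical helper
//   logits = runtime.Run({nextToken});
//
//   runtime.Reset();    // start a new sequence
//   runtime.Release();  // free model resources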