// xref: /aosp_15_r20/external/executorch/examples/mediatek/executor_runner/llama_runner/LlamaRuntime.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <executorch/runtime/platform/log.h>

#include "LlamaConfig.h"
#include "LlamaModelChunk.h"
#include "llm_helper/include/llm_types.h"

#include "llm_helper/include/rotary_embedding.h"
#include "llm_helper/include/token_embedding.h"

23 namespace example {
24 
25 class LlamaRuntime {
26  public:
LlamaRuntime()27   explicit LlamaRuntime() {}
~LlamaRuntime()28   ~LlamaRuntime() {}
29 
30   void Initialize(
31       const LlamaModelOptions& modelOptions,
32       const LlamaModelPaths& modelPaths);
33 
34   void Release();
35 
36   void SwapModel(const size_t batchSize);
37 
38   void* Run(
39       const std::vector<uint64_t>& inputTokens,
40       const bool lastLogits = true);
41 
42   void Reset();
43 
44   size_t GetTokenBatchSize() const;
45 
46   size_t GetTokenIndex() const;
47 
48   const LlamaModelOptions& GetModelOptions() const;
49 
50  private:
51   LlamaModelOptions mModelOptions;
52   std::vector<std::unique_ptr<ModelChunk>> mLlamaModelChunks;
53   std::unique_ptr<llm_helper::TokenEmbeddingLut> mTokenEmbLut;
54   std::unique_ptr<llm_helper::RotaryEmbeddingMasterLut> mRotEmbMasterLut;
55   size_t mTokenBatchSize = 1;
56   size_t mTokenIndex = 0;
57 };
58 
59 } // namespace example
60