xref: /aosp_15_r20/external/executorch/examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>

#include "LlamaConfig.h"
#include "ModelChunk.h"
#include "llm_helper/include/llm_types.h"
#include "llm_helper/include/mask_builder.h"
#include "llm_helper/include/rotary_embedding.h"
29 
30 namespace example {
31 
32 using llm_helper::MaskBuilder;
33 using llm_helper::RotaryEmbeddingMasterLut;
34 
35 using TensorShape = executorch::runtime::Span<const int32_t>;
36 using ModelIndexMap = std::unordered_map<size_t, size_t>;
37 
38 // Llama decoder chunk
39 class LlamaModelChunk : public ModelChunk {
40  private:
41   static constexpr size_t kCacheLengthDim = 2;
42 
43  public:
44   explicit LlamaModelChunk(
45       const ModelPathMap& modelPathMap,
46       const LlamaModelOptions& modelOptions,
47       const size_t initBatchSize,
48       const size_t numCache,
49       const size_t numRotEmbInputs,
50       const RotaryEmbeddingMasterLut* rotEmbMasterLut);
51 
52   ~LlamaModelChunk();
53 
54   virtual void Initialize() override;
55 
56   virtual void Run() override;
57 
58   virtual bool HotSwapModel(const size_t tokenBatchSize) override;
59 
60   void Reset();
61 
62   void SetLeftPadding(const size_t leftPadSize);
63 
64   void SetRightPadding(const size_t rightPadSize);
65 
66   void UpdatePosEmbAndMask(const size_t numInputToken);
67 
68   void AdvanceTokenIndex();
69 
70   size_t GetTokenIndex() const;
71 
72  private:
73   void SetPosEmbed(const size_t tokenIndex);
74 
75   void InitMaskBuilder();
76 
77   void InitCache();
78 
79   void PrepareCacheIOs();
80 
81   size_t GetCacheStrideSize() const;
82 
83   size_t GetCacheNumRows() const;
84 
85   size_t GetLeftPadding() const;
86 
87   size_t GetRightPadding() const;
88 
89   void PaddingPostprocess();
90 
91   virtual void LeftPaddingCachePostprocess();
92 
93   virtual void RightPaddingCachePostprocess();
94 
95   virtual void RollbackCache(
96       const size_t rollbackTokCount,
97       const size_t numSeenToken);
98 
99  private:
100   void CheckIoCount();
101 
102   size_t GetExpectedInputCount() const;
103 
104   size_t GetExpectedOutputCount() const;
105 
106  private:
107   // Input/Output Indexes
108   const size_t kMaskInputIndex;
109   const std::vector<size_t> kRotEmbInputIndexes;
110   const std::vector<size_t> kCacheInputIndexes;
111   const std::vector<size_t> kCacheOutputIndexes;
112 
113   // Cache
114   TensorShape mCacheShape;
115   const LLMType kCacheType;
116   const size_t kMaxTokenLength;
117   const size_t kCacheLength;
118   const size_t kCacheTypeSize;
119 
120   // Mask
121   const LLMType kMaskType;
122 
123   // Padding
124   size_t mCurrentPadSize = 0;
125   enum class PaddingMode { LEFT, RIGHT };
126   PaddingMode mPaddingMode = PaddingMode::RIGHT;
127 
128   // Lookup table for rotary embedding
129   const RotaryEmbeddingMasterLut* kRotEmbMasterLut;
130 
131   // Mask builder
132   std::unique_ptr<MaskBuilder> mMaskBuilder;
133 
134   // Keep track of token index. Its value can also be viewed as numSeenToken.
135   size_t mCurrentTokenIndex = 0;
136 };
137 
138 } // namespace example
139