/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>

#include "LlamaConfig.h"
#include "ModelChunk.h"
#include "llm_helper/include/llm_types.h"

#include "llm_helper/include/mask_builder.h"
#include "llm_helper/include/rotary_embedding.h"

namespace example {

using llm_helper::MaskBuilder;
using llm_helper::RotaryEmbeddingMasterLut;

using TensorShape = executorch::runtime::Span<const int32_t>;
using ModelIndexMap = std::unordered_map<size_t, size_t>;

// Llama decoder chunk
class LlamaModelChunk : public ModelChunk {
 private:
  static constexpr size_t kCacheLengthDim = 2;

 public:
  explicit LlamaModelChunk(
      const ModelPathMap& modelPathMap,
      const LlamaModelOptions& modelOptions,
      const size_t initBatchSize,
      const size_t numCache,
      const size_t numRotEmbInputs,
      const RotaryEmbeddingMasterLut* rotEmbMasterLut);

  ~LlamaModelChunk();

  virtual void Initialize() override;

  virtual void Run() override;

  virtual bool HotSwapModel(const size_t tokenBatchSize) override;

  void Reset();

  void SetLeftPadding(const size_t leftPadSize);

  void SetRightPadding(const size_t rightPadSize);

  void UpdatePosEmbAndMask(const size_t numInputToken);

  void AdvanceTokenIndex();

  size_t GetTokenIndex() const;

 private:
  void SetPosEmbed(const size_t tokenIndex);

  void InitMaskBuilder();

  void InitCache();

  void PrepareCacheIOs();

  size_t GetCacheStrideSize() const;

  size_t GetCacheNumRows() const;

  size_t GetLeftPadding() const;

  size_t GetRightPadding() const;

  void PaddingPostprocess();

  virtual void LeftPaddingCachePostprocess();

  virtual void RightPaddingCachePostprocess();

  virtual void RollbackCache(
      const size_t rollbackTokCount,
      const size_t numSeenToken);

 private:
  void CheckIoCount();

  size_t GetExpectedInputCount() const;

  size_t GetExpectedOutputCount() const;

 private:
  // Input/Output Indexes
  const size_t kMaskInputIndex;
  const std::vector<size_t> kRotEmbInputIndexes;
  const std::vector<size_t> kCacheInputIndexes;
  const std::vector<size_t> kCacheOutputIndexes;

  // Cache
  TensorShape mCacheShape;
  const LLMType kCacheType;
  const size_t kMaxTokenLength;
  const size_t kCacheLength;
  const size_t kCacheTypeSize;

  // Mask
  const LLMType kMaskType;

  // Padding
  size_t mCurrentPadSize = 0;
  enum class PaddingMode { LEFT, RIGHT };
  PaddingMode mPaddingMode = PaddingMode::RIGHT;

  // Lookup table for rotary embedding
  const RotaryEmbeddingMasterLut* kRotEmbMasterLut;

  // Mask builder
  std::unique_ptr<MaskBuilder> mMaskBuilder;

  // Keep track of the token index. Its value can also be viewed as numSeenToken.
  size_t mCurrentTokenIndex = 0;
};

} // namespace example
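
// ---------------------------------------------------------------------------
// Illustrative usage sketch (comment only, not part of the declared API).
// This is a hedged example of how a runner might drive a LlamaModelChunk,
// using only the methods declared above. The constructor argument values,
// the `modelPathMap`, `modelOptions`, `rotEmbMasterLut`, `numLayersInChunk`,
// and `done` names, and the exact call ordering are assumptions for
// illustration, not taken from this header.
//
//   example::LlamaModelChunk chunk(
//       modelPathMap,                       // model paths (assumed prepared)
//       modelOptions,                       // LlamaModelOptions (LlamaConfig.h)
//       /*initBatchSize=*/32,               // prompt token batch size (assumed)
//       /*numCache=*/2 * numLayersInChunk,  // K and V cache per layer (assumed)
//       /*numRotEmbInputs=*/2,              // e.g. cos/sin tables (assumed)
//       &rotEmbMasterLut);
//
//   chunk.Initialize();
//
//   // Prompt processing: refresh the mask and rotary embedding inputs for the
//   // tokens consumed in this step, run the chunk, then advance the index.
//   chunk.UpdatePosEmbAndMask(/*numInputToken=*/32);
//   chunk.Run();
//   chunk.AdvanceTokenIndex();
//
//   // Switch to a single-token model for autoregressive decoding.
//   chunk.HotSwapModel(/*tokenBatchSize=*/1);
//   while (!done) {
//     chunk.UpdatePosEmbAndMask(/*numInputToken=*/1);
//     chunk.Run();
//     chunk.AdvanceTokenIndex();
//   }
// ---------------------------------------------------------------------------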