1 /* 2 * Copyright (c) 2024 MediaTek Inc. 3 * 4 * Licensed under the BSD License (the "License"); you may not use this file 5 * except in compliance with the License. See the license file in the root 6 * directory of this source tree for more details. 7 */ 8 9 #pragma once 10 11 #include "llm_types.h" 12 13 #include <string> 14 15 namespace example { 16 namespace llm_helper { 17 18 class MaskBuilder { 19 public: 20 explicit MaskBuilder( 21 void* maskBuffer, 22 const size_t maskSizeBytes, 23 const LLMType maskType, 24 const size_t cacheLength); 25 26 ~MaskBuilder(); 27 28 // Build mask from scratch. 29 void buildMask(const size_t tokenBatchSize, const size_t numSeenToken); 30 31 // Only set mask to true for seen tokens. 32 // Will fallback to buildMask if mask is not updatable. 33 void updateMask( 34 const size_t tokenBatchSize, 35 const size_t numSeenToken, 36 const size_t length); 37 38 void notifyLeftPadding(const size_t padLength); 39 40 void notifyRightPadding(const size_t padLength); 41 42 // Mark mask as non-updatable which forces updateMask to call buildMask. 43 void markMaskDirty(); 44 45 // Update the model input mask size. Use raw byte size to account for any HW 46 // alignment. 47 void updateMaskSize(const size_t sizeBytes); 48 49 private: 50 template <typename MaskType> 51 void buildMask(const size_t tokenBatchSize, const size_t numSeenToken); 52 53 template <typename MaskType> 54 void updateMask( 55 const size_t tokenBatchSize, 56 const size_t numSeenToken, 57 const size_t length); 58 59 // Adjust mask for padded input, and returns whether mask is modified for 60 // padding. Used by buildMask/updateMask. 61 template <typename MaskType> 62 bool adjustMaskForPadding(const size_t tokenBatchSize); 63 64 private: 65 void* mMaskBuffer; 66 size_t mMaskSizeBytes; 67 const LLMType kMaskType; 68 const size_t kMaskTypeSize; 69 const size_t kCacheLength; 70 71 // Set by notifyLeftPadding/notifyRightPadding. Reset by adjustMaskForPadding. 72 size_t mLeftPadLength = 0; 73 size_t mRightPadLength = 0; 74 75 bool mIsMaskUpdatable = false; 76 }; 77 78 } // namespace llm_helper 79 } // namespace example 80