1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
11 #include "llm_types.h"
12 
13 #include <string>
14 
15 namespace example {
16 namespace llm_helper {
17 
18 class MaskBuilder {
19  public:
20   explicit MaskBuilder(
21       void* maskBuffer,
22       const size_t maskSizeBytes,
23       const LLMType maskType,
24       const size_t cacheLength);
25 
26   ~MaskBuilder();
27 
28   // Build mask from scratch.
29   void buildMask(const size_t tokenBatchSize, const size_t numSeenToken);
30 
31   // Only set mask to true for seen tokens.
32   // Will fallback to buildMask if mask is not updatable.
33   void updateMask(
34       const size_t tokenBatchSize,
35       const size_t numSeenToken,
36       const size_t length);
37 
38   void notifyLeftPadding(const size_t padLength);
39 
40   void notifyRightPadding(const size_t padLength);
41 
42   // Mark mask as non-updatable which forces updateMask to call buildMask.
43   void markMaskDirty();
44 
45   // Update the model input mask size. Use raw byte size to account for any HW
46   // alignment.
47   void updateMaskSize(const size_t sizeBytes);
48 
49  private:
50   template <typename MaskType>
51   void buildMask(const size_t tokenBatchSize, const size_t numSeenToken);
52 
53   template <typename MaskType>
54   void updateMask(
55       const size_t tokenBatchSize,
56       const size_t numSeenToken,
57       const size_t length);
58 
59   // Adjust mask for padded input, and returns whether mask is modified for
60   // padding. Used by buildMask/updateMask.
61   template <typename MaskType>
62   bool adjustMaskForPadding(const size_t tokenBatchSize);
63 
64  private:
65   void* mMaskBuffer;
66   size_t mMaskSizeBytes;
67   const LLMType kMaskType;
68   const size_t kMaskTypeSize;
69   const size_t kCacheLength;
70 
71   // Set by notifyLeftPadding/notifyRightPadding. Reset by adjustMaskForPadding.
72   size_t mLeftPadLength = 0;
73   size_t mRightPadLength = 0;
74 
75   bool mIsMaskUpdatable = false;
76 };
77 
78 } // namespace llm_helper
79 } // namespace example
80