xref: /aosp_15_r20/external/executorch/extension/llm/tokenizer/base64.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 // @lint-ignore-every LICENSELINT
9 /**************************************************************************
10    Copyright (c) 2023 sewenew
11 
12    Licensed under the Apache License, Version 2.0 (the "License");
13    you may not use this file except in compliance with the License.
14    You may obtain a copy of the License at
15 
16        http://www.apache.org/licenses/LICENSE-2.0
17 
18    Unless required by applicable law or agreed to in writing, software
19    distributed under the License is distributed on an "AS IS" BASIS,
20    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21    See the License for the specific language governing permissions and
22    limitations under the License.
23  *************************************************************************/
24 
25 #pragma once
26 
27 #include <executorch/runtime/core/error.h>
28 #include <executorch/runtime/core/result.h>
29 #include <executorch/runtime/platform/assert.h>
30 #include <cassert>
31 #include <string>
32 #include <string_view>
33 
34 namespace executorch {
35 namespace extension {
36 namespace llm {
37 using Error = executorch::runtime::Error;
38 template <typename T>
39 using Result = executorch::runtime::Result<T>;
40 
41 namespace base64 {
42 
43 Result<std::string> decode(const std::string_view& input);
44 
45 namespace detail {
46 
// Lookup table indexed by an unsigned byte value (0..255). Characters of the
// base64 alphabet ('A'-'Z', 'a'-'z', '0'-'9', '+', '/') map to their 6-bit
// values 0..63; every other byte, including the '=' padding character, maps
// to 255, which callers treat as "invalid".
//
// Declared `inline` so that every translation unit including this header
// shares one table. The inline functions below subscript it with a runtime
// index (an odr-use), which would otherwise refer to a distinct
// internal-linkage object in each translation unit.
inline constexpr uint32_t DECODE_TABLE[] = {
    // 0x00 - 0x1F: control characters
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    // 0x20 - 0x2F: '+' (0x2B) -> 62, '/' (0x2F) -> 63
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 62, 255, 255, 255, 63,
    // 0x30 - 0x3F: '0'..'9' -> 52..61
    52, 53, 54, 55, 56, 57, 58, 59,
    60, 61, 255, 255, 255, 255, 255, 255,
    // 0x40 - 0x5F: 'A'..'Z' -> 0..25
    255, 0, 1, 2, 3, 4, 5, 6,
    7, 8, 9, 10, 11, 12, 13, 14,
    15, 16, 17, 18, 19, 20, 21, 22,
    23, 24, 25, 255, 255, 255, 255, 255,
    // 0x60 - 0x7F: 'a'..'z' -> 26..51
    255, 26, 27, 28, 29, 30, 31, 32,
    33, 34, 35, 36, 37, 38, 39, 40,
    41, 42, 43, 44, 45, 46, 47, 48,
    49, 50, 51, 255, 255, 255, 255, 255,
    // 0x80 - 0xFF: non-ASCII bytes
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255};

// The decoders index the table with a uint8_t, so it must cover all 256 byte
// values.
static_assert(
    sizeof(DECODE_TABLE) / sizeof(DECODE_TABLE[0]) == 256,
    "DECODE_TABLE must have one entry per byte value");
66 
// Checks a value that the decoders in this file read from DECODE_TABLE.
// 255 is the table's marker for a byte that is not a base64 character; for
// that value the check below fails with InvalidArgument and the message
// "invalid char" (exact reporting behaviour is defined by the
// ET_CHECK_OR_RETURN_ERROR macro, which is not visible here). Any other value
// yields Error::Ok.
inline Error validate(uint32_t v) {
  ET_CHECK_OR_RETURN_ERROR(v != 255, InvalidArgument, "invalid char");
  return Error::Ok;
}
71 
decode(const std::string_view & input,std::string & output)72 inline Error decode(const std::string_view& input, std::string& output) {
73   ET_CHECK_OR_RETURN_ERROR(
74       input.size() == 4,
75       InvalidArgument,
76       "input length must be 4, got %zu",
77       input.size());
78 
79   uint32_t val = 0;
80 
81   uint8_t c = input[0];
82   auto v = DECODE_TABLE[c];
83   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
84   val = v;
85 
86   c = input[1];
87   v = DECODE_TABLE[c];
88   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
89   val = (val << 6) | v;
90 
91   c = input[2];
92   v = DECODE_TABLE[c];
93   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
94   val = (val << 6) | v;
95 
96   c = input[3];
97   v = DECODE_TABLE[c];
98   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
99   val = (val << 6) | v;
100 
101   output.push_back(static_cast<char>((val >> 16) & 0xFF));
102   output.push_back(static_cast<char>((val >> 8) & 0xFF));
103   output.push_back(static_cast<char>(val & 0xFF));
104   return Error::Ok;
105 }
106 
decode_1_padding(const std::string_view & input,std::string & output)107 inline Error decode_1_padding(
108     const std::string_view& input,
109     std::string& output) {
110   ET_CHECK_OR_RETURN_ERROR(
111       input.size() == 3,
112       InvalidArgument,
113       "input length must be 3, got %zu",
114       input.size());
115 
116   uint32_t val = 0;
117 
118   uint8_t c = input[0];
119   auto v = DECODE_TABLE[c];
120   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
121   val = v;
122 
123   c = input[1];
124   v = DECODE_TABLE[c];
125   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
126   val = (val << 6) | v;
127 
128   c = input[2];
129   v = DECODE_TABLE[c];
130   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
131   val = (val << 6) | v;
132 
133   output.push_back(static_cast<char>((val >> 10) & 0xFF));
134   output.push_back(static_cast<char>((val >> 2) & 0xFF));
135   return Error::Ok;
136 }
137 
decode_2_padding(const std::string_view & input,std::string & output)138 inline Error decode_2_padding(
139     const std::string_view& input,
140     std::string& output) {
141   ET_CHECK_OR_RETURN_ERROR(
142       input.size() == 2,
143       InvalidArgument,
144       "input length must be 2, got %zu",
145       input.size());
146 
147   uint32_t val = 0;
148 
149   uint8_t c = input[0];
150   auto v = DECODE_TABLE[c];
151   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
152   val = v;
153 
154   c = input[1];
155   v = DECODE_TABLE[c];
156   ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
157   val = (val << 6) | v;
158 
159   output.push_back(static_cast<char>((val >> 4) & 0xFF));
160   return Error::Ok;
161 }
162 
163 } // namespace detail
164 
decode(const std::string_view & input)165 inline Result<std::string> decode(const std::string_view& input) {
166   ET_CHECK_OR_RETURN_ERROR(!input.empty(), InvalidArgument, "empty input");
167 
168   // Faster than `input.size() % 4`.
169   ET_CHECK_OR_RETURN_ERROR(
170       (input.size() & 3) == 0 && input.size() >= 4,
171       InvalidArgument,
172       "input length must be larger than 4 and is multiple of 4, got %zu",
173       input.size());
174 
175   std::string output;
176   output.reserve(input.size() / 4 * 3);
177   auto idx = 0U;
178   for (; idx < input.size() - 4; idx += 4) {
179     ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
180   }
181 
182   // Last 4 bytes. Might contain paddings.
183   if (input[idx + 3] == '=') {
184     if (input[idx + 2] == '=') {
185       // Tow paddings.
186       ET_CHECK_OK_OR_RETURN_ERROR(
187           detail::decode_2_padding(input.substr(idx, 2), output));
188     } else {
189       // One padding.
190       ET_CHECK_OK_OR_RETURN_ERROR(
191           detail::decode_1_padding(input.substr(idx, 3), output));
192     }
193   } else {
194     // No padding.
195     ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
196   }
197 
198   return output;
199 }
200 
201 } // namespace base64
202 
203 } // namespace llm
204 } // namespace extension
205 } // namespace executorch
206 
207 namespace torch {
208 namespace executor {
209 namespace base64 {
210 // TODO(T197294990): Remove these deprecated aliases once all users have moved
211 // to the new `::executorch` namespaces.
212 using ::executorch::extension::llm::base64::decode;
213 } // namespace base64
214 } // namespace executor
215 } // namespace torch
216