1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8 // @lint-ignore-every LICENSELINT
9 /**************************************************************************
10 Copyright (c) 2023 sewenew
11
12 Licensed under the Apache License, Version 2.0 (the "License");
13 you may not use this file except in compliance with the License.
14 You may obtain a copy of the License at
15
16 http://www.apache.org/licenses/LICENSE-2.0
17
18 Unless required by applicable law or agreed to in writing, software
19 distributed under the License is distributed on an "AS IS" BASIS,
20 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21 See the License for the specific language governing permissions and
22 limitations under the License.
23 *************************************************************************/
24
25 #pragma once
26
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/assert.h>
#include <cassert>
#include <cstdint>
#include <string>
#include <string_view>
33
34 namespace executorch {
35 namespace extension {
36 namespace llm {
37 using Error = executorch::runtime::Error;
38 template <typename T>
39 using Result = executorch::runtime::Result<T>;
40
41 namespace base64 {
42
// Forward declaration of the public entry point. Its inline definition appears
// after the `detail` helpers further down in this header.
Result<std::string> decode(const std::string_view& input);
44
45 namespace detail {
46
/**
 * Lookup table mapping every possible byte value (0-255) to its 6-bit value in
 * the standard base64 alphabet:
 *   'A'-'Z' -> 0-25, 'a'-'z' -> 26-51, '0'-'9' -> 52-61, '+' -> 62, '/' -> 63.
 * Every byte outside the alphabet (including the '=' padding character) maps
 * to 255, which callers treat as "invalid character".
 *
 * Declared `inline` so that the inline functions in this header all refer to
 * one shared table instead of a separate internal-linkage copy per
 * translation unit.
 */
inline constexpr uint32_t DECODE_TABLE[] = {
    // 0x00 - 0x0F
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0x10 - 0x1F
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0x20 - 0x2F: '+' (0x2B) and '/' (0x2F)
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, 255, 63,
    // 0x30 - 0x3F: '0'-'9'
    52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255,
    // 0x40 - 0x4F: 'A'-'O'
    255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
    // 0x50 - 0x5F: 'P'-'Z'
    15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, 255,
    // 0x60 - 0x6F: 'a'-'o'
    255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
    // 0x70 - 0x7F: 'p'-'z'
    41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255,
    // 0x80 - 0x8F
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0x90 - 0x9F
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0xA0 - 0xAF
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0xB0 - 0xBF
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0xC0 - 0xCF
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0xD0 - 0xDF
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0xE0 - 0xEF
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
    // 0xF0 - 0xFF
    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255};

// The decoders index this table with a uint8_t, so it must cover all 256
// byte values for those lookups to stay in bounds.
static_assert(
    sizeof(DECODE_TABLE) / sizeof(DECODE_TABLE[0]) == 256,
    "DECODE_TABLE must have one entry per byte value");
66
validate(uint32_t v)67 inline Error validate(uint32_t v) {
68 ET_CHECK_OR_RETURN_ERROR(v != 255, InvalidArgument, "invalid char");
69 return Error::Ok;
70 }
71
decode(const std::string_view & input,std::string & output)72 inline Error decode(const std::string_view& input, std::string& output) {
73 ET_CHECK_OR_RETURN_ERROR(
74 input.size() == 4,
75 InvalidArgument,
76 "input length must be 4, got %zu",
77 input.size());
78
79 uint32_t val = 0;
80
81 uint8_t c = input[0];
82 auto v = DECODE_TABLE[c];
83 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
84 val = v;
85
86 c = input[1];
87 v = DECODE_TABLE[c];
88 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
89 val = (val << 6) | v;
90
91 c = input[2];
92 v = DECODE_TABLE[c];
93 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
94 val = (val << 6) | v;
95
96 c = input[3];
97 v = DECODE_TABLE[c];
98 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
99 val = (val << 6) | v;
100
101 output.push_back(static_cast<char>((val >> 16) & 0xFF));
102 output.push_back(static_cast<char>((val >> 8) & 0xFF));
103 output.push_back(static_cast<char>(val & 0xFF));
104 return Error::Ok;
105 }
106
decode_1_padding(const std::string_view & input,std::string & output)107 inline Error decode_1_padding(
108 const std::string_view& input,
109 std::string& output) {
110 ET_CHECK_OR_RETURN_ERROR(
111 input.size() == 3,
112 InvalidArgument,
113 "input length must be 3, got %zu",
114 input.size());
115
116 uint32_t val = 0;
117
118 uint8_t c = input[0];
119 auto v = DECODE_TABLE[c];
120 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
121 val = v;
122
123 c = input[1];
124 v = DECODE_TABLE[c];
125 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
126 val = (val << 6) | v;
127
128 c = input[2];
129 v = DECODE_TABLE[c];
130 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
131 val = (val << 6) | v;
132
133 output.push_back(static_cast<char>((val >> 10) & 0xFF));
134 output.push_back(static_cast<char>((val >> 2) & 0xFF));
135 return Error::Ok;
136 }
137
decode_2_padding(const std::string_view & input,std::string & output)138 inline Error decode_2_padding(
139 const std::string_view& input,
140 std::string& output) {
141 ET_CHECK_OR_RETURN_ERROR(
142 input.size() == 2,
143 InvalidArgument,
144 "input length must be 2, got %zu",
145 input.size());
146
147 uint32_t val = 0;
148
149 uint8_t c = input[0];
150 auto v = DECODE_TABLE[c];
151 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
152 val = v;
153
154 c = input[1];
155 v = DECODE_TABLE[c];
156 ET_CHECK_OK_OR_RETURN_ERROR(validate(v));
157 val = (val << 6) | v;
158
159 output.push_back(static_cast<char>((val >> 4) & 0xFF));
160 return Error::Ok;
161 }
162
163 } // namespace detail
164
decode(const std::string_view & input)165 inline Result<std::string> decode(const std::string_view& input) {
166 ET_CHECK_OR_RETURN_ERROR(!input.empty(), InvalidArgument, "empty input");
167
168 // Faster than `input.size() % 4`.
169 ET_CHECK_OR_RETURN_ERROR(
170 (input.size() & 3) == 0 && input.size() >= 4,
171 InvalidArgument,
172 "input length must be larger than 4 and is multiple of 4, got %zu",
173 input.size());
174
175 std::string output;
176 output.reserve(input.size() / 4 * 3);
177 auto idx = 0U;
178 for (; idx < input.size() - 4; idx += 4) {
179 ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
180 }
181
182 // Last 4 bytes. Might contain paddings.
183 if (input[idx + 3] == '=') {
184 if (input[idx + 2] == '=') {
185 // Tow paddings.
186 ET_CHECK_OK_OR_RETURN_ERROR(
187 detail::decode_2_padding(input.substr(idx, 2), output));
188 } else {
189 // One padding.
190 ET_CHECK_OK_OR_RETURN_ERROR(
191 detail::decode_1_padding(input.substr(idx, 3), output));
192 }
193 } else {
194 // No padding.
195 ET_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
196 }
197
198 return output;
199 }
200
201 } // namespace base64
202
203 } // namespace llm
204 } // namespace extension
205 } // namespace executorch
206
207 namespace torch {
208 namespace executor {
209 namespace base64 {
210 // TODO(T197294990): Remove these deprecated aliases once all users have moved
211 // to the new `::executorch` namespaces.
212 using ::executorch::extension::llm::base64::decode;
213 } // namespace base64
214 } // namespace executor
215 } // namespace torch
216