1 use crate::{
2     alphabet::Alphabet,
3     engine::{
4         general_purpose::{self, decode_table, encode_table},
5         Config, DecodeEstimate, DecodeMetadata, DecodePaddingMode, Engine,
6     },
7     DecodeError, DecodeSliceError,
8 };
9 use std::ops::{BitAnd, BitOr, Shl, Shr};
10 
11 /// Comparatively simple implementation that can be used as something to compare against in tests
12 pub struct Naive {
13     encode_table: [u8; 64],
14     decode_table: [u8; 256],
15     config: NaiveConfig,
16 }
17 
18 impl Naive {
19     const ENCODE_INPUT_CHUNK_SIZE: usize = 3;
20     const DECODE_INPUT_CHUNK_SIZE: usize = 4;
21 
new(alphabet: &Alphabet, config: NaiveConfig) -> Self22     pub const fn new(alphabet: &Alphabet, config: NaiveConfig) -> Self {
23         Self {
24             encode_table: encode_table(alphabet),
25             decode_table: decode_table(alphabet),
26             config,
27         }
28     }
29 
decode_byte_into_u32(&self, offset: usize, byte: u8) -> Result<u32, DecodeError>30     fn decode_byte_into_u32(&self, offset: usize, byte: u8) -> Result<u32, DecodeError> {
31         let decoded = self.decode_table[byte as usize];
32 
33         if decoded == general_purpose::INVALID_VALUE {
34             return Err(DecodeError::InvalidByte(offset, byte));
35         }
36 
37         Ok(decoded as u32)
38     }
39 }
40 
41 impl Engine for Naive {
42     type Config = NaiveConfig;
43     type DecodeEstimate = NaiveEstimate;
44 
internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize45     fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize {
46         // complete chunks first
47 
48         const LOW_SIX_BITS: u32 = 0x3F;
49 
50         let rem = input.len() % Self::ENCODE_INPUT_CHUNK_SIZE;
51         // will never underflow
52         let complete_chunk_len = input.len() - rem;
53 
54         let mut input_index = 0_usize;
55         let mut output_index = 0_usize;
56         if let Some(last_complete_chunk_index) =
57             complete_chunk_len.checked_sub(Self::ENCODE_INPUT_CHUNK_SIZE)
58         {
59             while input_index <= last_complete_chunk_index {
60                 let chunk = &input[input_index..input_index + Self::ENCODE_INPUT_CHUNK_SIZE];
61 
62                 // populate low 24 bits from 3 bytes
63                 let chunk_int: u32 =
64                     (chunk[0] as u32).shl(16) | (chunk[1] as u32).shl(8) | (chunk[2] as u32);
65                 // encode 4x 6-bit output bytes
66                 output[output_index] = self.encode_table[chunk_int.shr(18) as usize];
67                 output[output_index + 1] =
68                     self.encode_table[chunk_int.shr(12_u8).bitand(LOW_SIX_BITS) as usize];
69                 output[output_index + 2] =
70                     self.encode_table[chunk_int.shr(6_u8).bitand(LOW_SIX_BITS) as usize];
71                 output[output_index + 3] =
72                     self.encode_table[chunk_int.bitand(LOW_SIX_BITS) as usize];
73 
74                 input_index += Self::ENCODE_INPUT_CHUNK_SIZE;
75                 output_index += 4;
76             }
77         }
78 
79         // then leftovers
80         if rem == 2 {
81             let chunk = &input[input_index..input_index + 2];
82 
83             // high six bits of chunk[0]
84             output[output_index] = self.encode_table[chunk[0].shr(2) as usize];
85             // bottom 2 bits of [0], high 4 bits of [1]
86             output[output_index + 1] =
87                 self.encode_table[(chunk[0].shl(4_u8).bitor(chunk[1].shr(4_u8)) as u32)
88                     .bitand(LOW_SIX_BITS) as usize];
89             // bottom 4 bits of [1], with the 2 bottom bits as zero
90             output[output_index + 2] =
91                 self.encode_table[(chunk[1].shl(2_u8) as u32).bitand(LOW_SIX_BITS) as usize];
92 
93             output_index += 3;
94         } else if rem == 1 {
95             let byte = input[input_index];
96             output[output_index] = self.encode_table[byte.shr(2) as usize];
97             output[output_index + 1] =
98                 self.encode_table[(byte.shl(4_u8) as u32).bitand(LOW_SIX_BITS) as usize];
99             output_index += 2;
100         }
101 
102         output_index
103     }
104 
internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate105     fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate {
106         NaiveEstimate::new(input_len)
107     }
108 
internal_decode( &self, input: &[u8], output: &mut [u8], estimate: Self::DecodeEstimate, ) -> Result<DecodeMetadata, DecodeSliceError>109     fn internal_decode(
110         &self,
111         input: &[u8],
112         output: &mut [u8],
113         estimate: Self::DecodeEstimate,
114     ) -> Result<DecodeMetadata, DecodeSliceError> {
115         let complete_nonterminal_quads_len = general_purpose::decode::complete_quads_len(
116             input,
117             estimate.rem,
118             output.len(),
119             &self.decode_table,
120         )?;
121 
122         const BOTTOM_BYTE: u32 = 0xFF;
123 
124         for (chunk_index, chunk) in input[..complete_nonterminal_quads_len]
125             .chunks_exact(4)
126             .enumerate()
127         {
128             let input_index = chunk_index * 4;
129             let output_index = chunk_index * 3;
130 
131             let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18)
132                 | self
133                     .decode_byte_into_u32(input_index + 1, chunk[1])?
134                     .shl(12)
135                 | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6)
136                 | self.decode_byte_into_u32(input_index + 3, chunk[3])?;
137 
138             output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8;
139             output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8;
140             output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8;
141         }
142 
143         general_purpose::decode_suffix::decode_suffix(
144             input,
145             complete_nonterminal_quads_len,
146             output,
147             complete_nonterminal_quads_len / 4 * 3,
148             &self.decode_table,
149             self.config.decode_allow_trailing_bits,
150             self.config.decode_padding_mode,
151         )
152     }
153 
config(&self) -> &Self::Config154     fn config(&self) -> &Self::Config {
155         &self.config
156     }
157 }
158 
/// Decode-length estimate produced by the `Naive` engine.
///
/// Records how an input length splits into whole `Naive::DECODE_INPUT_CHUNK_SIZE`-symbol chunks
/// and a leftover.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NaiveEstimate {
    /// Remainder from dividing the input length by `Naive::DECODE_INPUT_CHUNK_SIZE`.
    rem: usize,
    /// Length of the input prefix made up of complete `Naive::DECODE_INPUT_CHUNK_SIZE`-length
    /// chunks.
    complete_chunk_len: usize,
}
165 
166 impl NaiveEstimate {
new(input_len: usize) -> Self167     fn new(input_len: usize) -> Self {
168         let rem = input_len % Naive::DECODE_INPUT_CHUNK_SIZE;
169         let complete_chunk_len = input_len - rem;
170 
171         Self {
172             rem,
173             complete_chunk_len,
174         }
175     }
176 }
177 
178 impl DecodeEstimate for NaiveEstimate {
decoded_len_estimate(&self) -> usize179     fn decoded_len_estimate(&self) -> usize {
180         ((self.complete_chunk_len / 4) + ((self.rem > 0) as usize)) * 3
181     }
182 }
183 
184 #[derive(Clone, Copy, Debug)]
185 pub struct NaiveConfig {
186     pub encode_padding: bool,
187     pub decode_allow_trailing_bits: bool,
188     pub decode_padding_mode: DecodePaddingMode,
189 }
190 
191 impl Config for NaiveConfig {
encode_padding(&self) -> bool192     fn encode_padding(&self) -> bool {
193         self.encode_padding
194     }
195 }
196