1 use crate::{
2     engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode},
3     DecodeError, DecodeSliceError, PAD_BYTE,
4 };
5 
6 /// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided
7 /// parameters.
8 ///
9 /// Returns the decode metadata representing the total number of bytes decoded, including the ones
10 /// indicated as already written by `output_index`.
decode_suffix( input: &[u8], input_index: usize, output: &mut [u8], mut output_index: usize, decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, ) -> Result<DecodeMetadata, DecodeSliceError>11 pub(crate) fn decode_suffix(
12     input: &[u8],
13     input_index: usize,
14     output: &mut [u8],
15     mut output_index: usize,
16     decode_table: &[u8; 256],
17     decode_allow_trailing_bits: bool,
18     padding_mode: DecodePaddingMode,
19 ) -> Result<DecodeMetadata, DecodeSliceError> {
20     debug_assert!((input.len() - input_index) <= 4);
21 
22     // Decode any leftovers that might not be a complete input chunk of 4 bytes.
23     // Use a u32 as a stack-resident 4 byte buffer.
24     let mut morsels_in_leftover = 0;
25     let mut padding_bytes_count = 0;
26     // offset from input_index
27     let mut first_padding_offset: usize = 0;
28     let mut last_symbol = 0_u8;
29     let mut morsels = [0_u8; 4];
30 
31     for (leftover_index, &b) in input[input_index..].iter().enumerate() {
32         // '=' padding
33         if b == PAD_BYTE {
34             // There can be bad padding bytes in a few ways:
35             // 1 - Padding with non-padding characters after it
36             // 2 - Padding after zero or one characters in the current quad (should only
37             //     be after 2 or 3 chars)
38             // 3 - More than two characters of padding. If 3 or 4 padding chars
39             //     are in the same quad, that implies it will be caught by #2.
40             //     If it spreads from one quad to another, it will be an invalid byte
41             //     in the first quad.
42             // 4 - Non-canonical padding -- 1 byte when it should be 2, etc.
43             //     Per config, non-canonical but still functional non- or partially-padded base64
44             //     may be treated as an error condition.
45 
46             if leftover_index < 2 {
47                 // Check for error #2.
48                 // Either the previous byte was padding, in which case we would have already hit
49                 // this case, or it wasn't, in which case this is the first such error.
50                 debug_assert!(
51                     leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0)
52                 );
53                 let bad_padding_index = input_index + leftover_index;
54                 return Err(DecodeError::InvalidByte(bad_padding_index, b).into());
55             }
56 
57             if padding_bytes_count == 0 {
58                 first_padding_offset = leftover_index;
59             }
60 
61             padding_bytes_count += 1;
62             continue;
63         }
64 
65         // Check for case #1.
66         // To make '=' handling consistent with the main loop, don't allow
67         // non-suffix '=' in trailing chunk either. Report error as first
68         // erroneous padding.
69         if padding_bytes_count > 0 {
70             return Err(
71                 DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(),
72             );
73         }
74 
75         last_symbol = b;
76 
77         // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
78         // Pack the leftovers from left to right.
79         let morsel = decode_table[b as usize];
80         if morsel == INVALID_VALUE {
81             return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into());
82         }
83 
84         morsels[morsels_in_leftover] = morsel;
85         morsels_in_leftover += 1;
86     }
87 
88     // If there was 1 trailing byte, and it was valid, and we got to this point without hitting
89     // an invalid byte, now we can report invalid length
90     if !input.is_empty() && morsels_in_leftover < 2 {
91         return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into());
92     }
93 
94     match padding_mode {
95         DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ }
96         DecodePaddingMode::RequireCanonical => {
97             // allow empty input
98             if (padding_bytes_count + morsels_in_leftover) % 4 != 0 {
99                 return Err(DecodeError::InvalidPadding.into());
100             }
101         }
102         DecodePaddingMode::RequireNone => {
103             if padding_bytes_count > 0 {
104                 // check at the end to make sure we let the cases of padding that should be InvalidByte
105                 // get hit
106                 return Err(DecodeError::InvalidPadding.into());
107             }
108         }
109     }
110 
111     // When encoding 1 trailing byte (e.g. 0xFF), 2 base64 bytes ("/w") are needed.
112     // / is the symbol for 63 (0x3F, bottom 6 bits all set) and w is 48 (0x30, top 2 bits
113     // of bottom 6 bits set).
114     // When decoding two symbols back to one trailing byte, any final symbol higher than
115     // w would still decode to the original byte because we only care about the top two
116     // bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a
117     // mask based on how many bits are used for just the canonical encoding, and optionally
118     // error if any other bits are set. In the example of one encoded byte -> 2 symbols,
119     // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and
120     // useless since there are no more symbols to provide the necessary 4 additional bits
121     // to finish the second original byte.
122 
123     let leftover_bytes_to_append = morsels_in_leftover * 6 / 8;
124     // Put the up to 6 complete bytes as the high bytes.
125     // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split.
126     let mut leftover_num = (u32::from(morsels[0]) << 26)
127         | (u32::from(morsels[1]) << 20)
128         | (u32::from(morsels[2]) << 14)
129         | (u32::from(morsels[3]) << 8);
130 
131     // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
132     // will not be included in the output
133     let mask = !0_u32 >> (leftover_bytes_to_append * 8);
134     if !decode_allow_trailing_bits && (leftover_num & mask) != 0 {
135         // last morsel is at `morsels_in_leftover` - 1
136         return Err(DecodeError::InvalidLastSymbol(
137             input_index + morsels_in_leftover - 1,
138             last_symbol,
139         )
140         .into());
141     }
142 
143     // Strangely, this approach benchmarks better than writing bytes one at a time,
144     // or copy_from_slice into output.
145     for _ in 0..leftover_bytes_to_append {
146         let hi_byte = (leftover_num >> 24) as u8;
147         leftover_num <<= 8;
148         *output
149             .get_mut(output_index)
150             .ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte;
151         output_index += 1;
152     }
153 
154     Ok(DecodeMetadata::new(
155         output_index,
156         if padding_bytes_count > 0 {
157             Some(input_index + first_padding_offset)
158         } else {
159             None
160         },
161     ))
162 }
163