1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::collections::HashMap;
16 use std::ptr::null_mut;
17
18 use crypto_provider_default::CryptoProviderImpl as CryptoProvider;
19 use lazy_static::lazy_static;
20 use lock_adapter::NoPoisonMutex;
21 use rand::Rng;
22 use rand_chacha::rand_core::SeedableRng;
23 use rand_chacha::ChaCha20Rng;
24
25 #[cfg(not(feature = "std"))]
26 use lock_adapter::spin::Mutex;
27 #[cfg(feature = "std")]
28 use lock_adapter::stdlib::Mutex;
29
30 use ukey2_connections::{
31 D2DConnectionContextV1, D2DHandshakeContext, HandleMessageError, HandshakeImplementation,
32 InitiatorD2DHandshakeContext, NextProtocol, ServerD2DHandshakeContext,
33 };
34
/// A byte buffer whose backing storage was allocated by Rust and is handed across the FFI
/// boundary as its raw parts.
///
/// The three fields are exactly the raw parts of a `Vec<u8>` (pointer, length, capacity), so the
/// buffer can later be rebuilt and freed on the Rust side. Functions in this file signal "no
/// data" with a null `ptr` and `len`/`cap` set to `usize::MAX`.
#[repr(C)]
pub struct RustFFIByteArray {
    ptr: *mut u8,
    len: usize,
    cap: usize,
}

impl RustFFIByteArray {
    /// Takes ownership of `vec` and exposes its allocation as raw parts without freeing it.
    fn from_vec(vec: Vec<u8>) -> RustFFIByteArray {
        // ManuallyDrop keeps the allocation alive once `leaked` goes out of scope; ownership is
        // now represented by the returned struct.
        let mut leaked = core::mem::ManuallyDrop::new(vec);
        let (ptr, len, cap) = (leaked.as_mut_ptr(), leaked.len(), leaked.capacity());
        RustFFIByteArray { ptr, len, cap }
    }

    /// Rebuilds the `Vec<u8>` described by this array, or returns `None` when `ptr` is null.
    ///
    /// # Safety
    /// A non-null `ptr` must come from [`RustFFIByteArray::from_vec`] together with unmodified
    /// `len` and `cap`, and the array must not be converted back more than once.
    unsafe fn into_vec(self) -> Option<Vec<u8>> {
        let RustFFIByteArray { ptr, len, cap } = self;
        if ptr.is_null() {
            None
        } else {
            Some(Vec::from_raw_parts(ptr, len, cap))
        }
    }
}
55
/// A byte buffer supplied by the foreign caller: `len` bytes starting at `ptr`.
///
/// The functions in this file only view it as a slice for the duration of the call; nothing in
/// this file frees or resizes it.
#[repr(C)]
pub struct CFFIByteArray {
    ptr: *mut u8,
    len: usize,
}
61
/// Value returned by `parse_handshake_message`.
///
/// * `success` - the outcome flag reported by `parse_handshake_message`.
/// * `alert_to_send` - when handling the message produced an `ErrorMessage`, the Rust-allocated
///   alert bytes carried by that error; otherwise a null `ptr` with `len` and `cap` set to
///   `usize::MAX`.
#[repr(C)]
pub struct CMessageParseResult {
    success: bool,
    alert_to_send: RustFFIByteArray,
}
67
// Type-erased handshake context; this is the value type stored in `HANDLE_MAPPING`.
type D2DBox = Box<dyn D2DHandshakeContext>;
// Boxed connection context; this is the value type stored in `CONNECTION_HANDLE_MAPPING`.
type ConnectionBox = Box<D2DConnectionContextV1>;

lazy_static! {
    // Handshake contexts keyed by the `u64` handle returned to the FFI caller.
    static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
    // Connection contexts keyed by the `u64` handle returned to the FFI caller.
    static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
        Mutex::new(HashMap::new());
    // Entropy-seeded ChaCha20 generator that `generate_handle` draws handle values from.
    static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
}
77
generate_handle() -> u6478 fn generate_handle() -> u64 {
79 RNG.lock().gen()
80 }
81
insert_gen_handle(item: D2DBox) -> u6482 fn insert_gen_handle(item: D2DBox) -> u64 {
83 let handle = generate_handle();
84 HANDLE_MAPPING.lock().insert(handle, item);
85 handle
86 }
87
insert_conn_gen_handle(item: ConnectionBox) -> u6488 fn insert_conn_gen_handle(item: ConnectionBox) -> u64 {
89 let handle = generate_handle();
90 CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
91 handle
92 }
93
94 // Utilities
95 /// This function deallocates FFIByteArray instances allocated from Rust only.
96 /// NOTE: Any FFIByteArray instances deallocated by this function will no longer be in a guaranteed
97 /// usable state.
98 ///
99 /// # Safety
100 /// The array must have been allocated by a Rust function with the Rust allocator, e.g.
101 /// [get_next_handshake_message].
102 #[no_mangle]
rust_dealloc_ffi_byte_array(arr: RustFFIByteArray)103 pub unsafe extern "C" fn rust_dealloc_ffi_byte_array(arr: RustFFIByteArray) {
104 if let Some(vec) = arr.into_vec() {
105 core::mem::drop(vec);
106 }
107 }
108
109 // Common functions
110 #[no_mangle]
is_handshake_complete(handle: u64) -> bool111 pub extern "C" fn is_handshake_complete(handle: u64) -> bool {
112 HANDLE_MAPPING.lock().get(&handle).map_or(false, |ctx| ctx.is_handshake_complete())
113 }
114
115 #[no_mangle]
get_next_handshake_message(handle: u64) -> RustFFIByteArray116 pub extern "C" fn get_next_handshake_message(handle: u64) -> RustFFIByteArray {
117 // TODO: error handling
118 let opt_msg = HANDLE_MAPPING.lock().get(&handle).and_then(|c| c.get_next_handshake_message());
119 if let Some(msg) = opt_msg {
120 RustFFIByteArray::from_vec(msg)
121 } else {
122 RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX }
123 }
124 }
125
126 /// # Safety
127 /// We treat msg as data, so we should never have an issue trying to execute it.
128 #[no_mangle]
parse_handshake_message( handle: u64, arr: CFFIByteArray, ) -> CMessageParseResult129 pub unsafe extern "C" fn parse_handshake_message(
130 handle: u64,
131 arr: CFFIByteArray,
132 ) -> CMessageParseResult {
133 let msg = std::slice::from_raw_parts(arr.ptr, arr.len);
134 let result = HANDLE_MAPPING.lock().get_mut(&handle).unwrap().handle_handshake_message(msg);
135 if let Err(error) = result {
136 match error {
137 HandleMessageError::InvalidState | HandleMessageError::BadMessage => {
138 log::error!("{:?}", error);
139 }
140 HandleMessageError::ErrorMessage(message) => {
141 return CMessageParseResult {
142 success: false,
143 alert_to_send: RustFFIByteArray::from_vec(message),
144 };
145 }
146 }
147 }
148 CMessageParseResult {
149 success: true,
150 alert_to_send: RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX },
151 }
152 }
153
154 #[no_mangle]
get_verification_string(handle: u64, length: usize) -> RustFFIByteArray155 pub extern "C" fn get_verification_string(handle: u64, length: usize) -> RustFFIByteArray {
156 HANDLE_MAPPING
157 .lock()
158 .get(&handle)
159 .map(|h| {
160 let auth_vec = h
161 .to_completed_handshake()
162 .unwrap()
163 .auth_string::<CryptoProvider>()
164 .derive_vec(length)
165 .unwrap();
166 RustFFIByteArray::from_vec(auth_vec)
167 })
168 .unwrap()
169 }
170
171 #[no_mangle]
to_connection_context(handle: u64) -> u64172 pub extern "C" fn to_connection_context(handle: u64) -> u64 {
173 // TODO: error handling
174 let ctx = HANDLE_MAPPING
175 .lock()
176 .remove(&handle)
177 .map(move |mut ctx| {
178 let result = Box::new(ctx.to_connection_context().unwrap());
179 drop(ctx);
180 result
181 })
182 .unwrap();
183 insert_conn_gen_handle(ctx)
184 }
185
186 // Responder-specific functions
187 #[no_mangle]
responder_new() -> u64188 pub extern "C" fn responder_new() -> u64 {
189 let ctx = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
190 HandshakeImplementation::PublicKeyInProtobuf,
191 &[NextProtocol::Aes256CbcHmacSha256],
192 ));
193 insert_gen_handle(ctx)
194 }
195
196 // Initiator-specific functions
197
198 /// # Safety
199 /// We treat next_protocol as data, not as executable memory.
200 #[no_mangle]
initiator_new() -> u64201 pub extern "C" fn initiator_new() -> u64 {
202 let ctx = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
203 HandshakeImplementation::PublicKeyInProtobuf,
204 vec![NextProtocol::Aes256CbcHmacSha256],
205 ));
206 insert_gen_handle(ctx)
207 }
208
209 // Connection Context
210
211 /// # Safety
212 /// We treat msg and associated_data as data, not as executable memory.
213 /// associated_data and msg are slices so Rust won't try to do anything weird with allocation.
214 #[no_mangle]
encode_message_to_peer( handle: u64, msg: CFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray215 pub unsafe extern "C" fn encode_message_to_peer(
216 handle: u64,
217 msg: CFFIByteArray,
218 associated_data: CFFIByteArray,
219 ) -> RustFFIByteArray {
220 if msg.len == 0 {
221 return RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX };
222 }
223 let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
224 let associated_data = if !associated_data.ptr.is_null() {
225 Some(std::slice::from_raw_parts(associated_data.ptr, associated_data.len))
226 } else {
227 None
228 };
229 let ret = CONNECTION_HANDLE_MAPPING
230 .lock()
231 .get_mut(&handle)
232 .map(|c| c.encode_message_to_peer::<CryptoProvider, _>(msg, associated_data));
233 if let Some(msg) = ret {
234 RustFFIByteArray::from_vec(msg)
235 } else {
236 log::error!("Was unable to find handle!");
237 RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX }
238 }
239 }
240
241 /// # Safety
242 /// We treat msg as data, not as executable memory.
243 #[no_mangle]
decode_message_from_peer( handle: u64, msg: CFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray244 pub unsafe extern "C" fn decode_message_from_peer(
245 handle: u64,
246 msg: CFFIByteArray,
247 associated_data: CFFIByteArray,
248 ) -> RustFFIByteArray {
249 if msg.len == 0 {
250 return RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX };
251 }
252 let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
253 let associated_data = if !associated_data.ptr.is_null() {
254 Some(std::slice::from_raw_parts(associated_data.ptr, associated_data.len))
255 } else {
256 None
257 };
258 let ret: Result<Vec<u8>, ukey2_connections::DecodeError> = CONNECTION_HANDLE_MAPPING
259 .lock()
260 .get_mut(&handle)
261 .unwrap()
262 .decode_message_from_peer::<CryptoProvider, _>(msg, associated_data);
263 if let Ok(decoded) = ret {
264 RustFFIByteArray::from_vec(decoded)
265 } else {
266 RustFFIByteArray { ptr: null_mut(), len: usize::MAX, cap: usize::MAX }
267 }
268 }
269
270 #[no_mangle]
get_session_unique(handle: u64) -> RustFFIByteArray271 pub extern "C" fn get_session_unique(handle: u64) -> RustFFIByteArray {
272 let session_unique_bytes = CONNECTION_HANDLE_MAPPING
273 .lock()
274 .get(&handle)
275 .unwrap()
276 .get_session_unique::<CryptoProvider>();
277 RustFFIByteArray::from_vec(session_unique_bytes)
278 }
279
280 #[no_mangle]
get_sequence_number_for_encoding(handle: u64) -> i32281 pub extern "C" fn get_sequence_number_for_encoding(handle: u64) -> i32 {
282 CONNECTION_HANDLE_MAPPING.lock().get(&handle).unwrap().get_sequence_number_for_encoding()
283 }
284
285 #[no_mangle]
get_sequence_number_for_decoding(handle: u64) -> i32286 pub extern "C" fn get_sequence_number_for_decoding(handle: u64) -> i32 {
287 CONNECTION_HANDLE_MAPPING.lock().get(&handle).unwrap().get_sequence_number_for_decoding()
288 }
289
290 #[no_mangle]
save_session(handle: u64) -> RustFFIByteArray291 pub extern "C" fn save_session(handle: u64) -> RustFFIByteArray {
292 let key = CONNECTION_HANDLE_MAPPING.lock().get(&handle).unwrap().save_session();
293 RustFFIByteArray::from_vec(key)
294 }
295
/// Outcome code returned across the FFI boundary, represented as an `i32`.
///
/// With no explicit discriminants, `Good` is `0` and `Error` is `1`.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The operation succeeded.
    Good,
    /// The operation failed.
    Error,
}
302
/// Value returned by `from_saved_session`.
///
/// * `handle` - handle of the restored connection when `status` is `Status::Good`;
///   `u64::MAX` when restoring failed.
/// * `status` - `Status::Good` on success, `Status::Error` on failure.
#[repr(C)]
pub struct CD2DRestoreConnectionContextV1Result {
    handle: u64,
    status: Status,
}
308
309 /// # Safety
310 /// We error out if the length is incorrect (too large or too small) for restoring a session.
311 #[no_mangle]
from_saved_session( arr: CFFIByteArray, ) -> CD2DRestoreConnectionContextV1Result312 pub unsafe extern "C" fn from_saved_session(
313 arr: CFFIByteArray,
314 ) -> CD2DRestoreConnectionContextV1Result {
315 let saved_session = std::slice::from_raw_parts(arr.ptr, arr.len);
316 let ctx = D2DConnectionContextV1::from_saved_session::<CryptoProvider>(saved_session);
317 if let Ok(conn_ctx) = ctx {
318 let final_ctx = Box::new(conn_ctx);
319 CD2DRestoreConnectionContextV1Result {
320 handle: insert_conn_gen_handle(final_ctx),
321 status: Status::Good,
322 }
323 } else {
324 log::error!("failed to restore session with error {:?}", ctx.unwrap_err());
325 CD2DRestoreConnectionContextV1Result { handle: u64::MAX, status: Status::Error }
326 }
327 }
328