1 // Copyright 2022 The ChromiumOS Authors 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 use std::ffi::c_void; 6 use std::ffi::OsString; 7 use std::io; 8 use std::ptr; 9 10 use winapi::shared::minwindef::ULONG; 11 use winapi::um::winnt::PVOID; 12 13 use super::unicode_string_to_os_string; 14 15 // Required for Windows API FFI bindings, as the names of the FFI structs and 16 // functions get called out by the linter. 17 #[allow(non_upper_case_globals)] 18 #[allow(non_camel_case_types)] 19 #[allow(non_snake_case)] 20 #[allow(dead_code)] 21 mod dll_notification_sys { 22 use std::io; 23 24 use winapi::shared::minwindef::ULONG; 25 use winapi::shared::ntdef::NTSTATUS; 26 use winapi::shared::ntdef::PCUNICODE_STRING; 27 use winapi::shared::ntstatus::STATUS_SUCCESS; 28 use winapi::um::libloaderapi::GetModuleHandleA; 29 use winapi::um::libloaderapi::GetProcAddress; 30 use winapi::um::winnt::CHAR; 31 use winapi::um::winnt::PVOID; 32 33 #[repr(C)] 34 pub union _LDR_DLL_NOTIFICATION_DATA { 35 pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA, 36 pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA, 37 } 38 pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA; 39 pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA; 40 41 #[repr(C)] 42 #[derive(Debug, Copy, Clone)] 43 pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA { 44 pub Flags: ULONG, // Reserved. 45 pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module. 46 pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module. 47 pub DllBase: PVOID, // A pointer to the base address for the DLL in memory. 48 pub SizeOfImage: ULONG, // The size of the DLL image, in bytes. 49 } 50 pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA; 51 pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA; 52 53 #[repr(C)] 54 #[derive(Debug, Copy, Clone)] 55 pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA { 56 pub Flags: ULONG, // Reserved. 57 pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module. 58 pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module. 59 pub DllBase: PVOID, // A pointer to the base address for the DLL in memory. 60 pub SizeOfImage: ULONG, // The size of the DLL image, in bytes. 61 } 62 pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA; 63 pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA; 64 65 pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1; 66 pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2; 67 68 const NTDLL: &[u8] = b"ntdll\0"; 69 const LDR_REGISTER_DLL_NOTIFICATION: &[u8] = b"LdrRegisterDllNotification\0"; 70 const LDR_UNREGISTER_DLL_NOTIFICATION: &[u8] = b"LdrUnregisterDllNotification\0"; 71 72 pub type LdrDllNotification = unsafe extern "C" fn( 73 NotificationReason: ULONG, 74 NotificationData: PLDR_DLL_NOTIFICATION_DATA, 75 Context: PVOID, 76 ); 77 78 pub type FnLdrRegisterDllNotification = 79 unsafe extern "C" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS; 80 pub type FnLdrUnregisterDllNotification = unsafe extern "C" fn(PVOID) -> NTSTATUS; 81 82 extern "C" { RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG83 pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG; 84 } 85 86 /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically 87 /// gets the address of the function and invokes the function with the given 88 /// arguments. 89 /// 90 /// # Safety 91 /// Unsafe as this function does not verify its arguments; the caller is 92 /// expected to verify the safety as if invoking the underlying C function. LdrRegisterDllNotification( Flags: ULONG, NotificationFunction: LdrDllNotification, Context: PVOID, Cookie: *mut PVOID, ) -> io::Result<()>93 pub unsafe fn LdrRegisterDllNotification( 94 Flags: ULONG, 95 NotificationFunction: LdrDllNotification, 96 Context: PVOID, 97 Cookie: *mut PVOID, 98 ) -> io::Result<()> { 99 let proc_addr = GetProcAddress( 100 /* hModule= */ 101 GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR), 102 /* lpProcName= */ 103 LDR_REGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR, 104 ); 105 if proc_addr.is_null() { 106 return Err(std::io::Error::last_os_error()); 107 } 108 let ldr_register_dll_notification: FnLdrRegisterDllNotification = 109 std::mem::transmute(proc_addr); 110 let ret = ldr_register_dll_notification(Flags, NotificationFunction, Context, Cookie); 111 if ret != STATUS_SUCCESS { 112 return Err(io::Error::from_raw_os_error( 113 RtlNtStatusToDosError(/* Status= */ ret) as i32, 114 )); 115 }; 116 Ok(()) 117 } 118 119 /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically 120 /// gets the address of the function and invokes the function with the given 121 /// arguments. 122 /// 123 /// # Safety 124 /// Unsafe as this function does not verify its arguments; the caller is 125 /// expected to verify the safety as if invoking the underlying C function. LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()>126 pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> { 127 let proc_addr = GetProcAddress( 128 /* hModule= */ 129 GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR), 130 /* lpProcName= */ 131 LDR_UNREGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR, 132 ); 133 if proc_addr.is_null() { 134 return Err(std::io::Error::last_os_error()); 135 } 136 let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification = 137 std::mem::transmute(proc_addr); 138 let ret = ldr_unregister_dll_notification(Cookie); 139 if ret != STATUS_SUCCESS { 140 return Err(io::Error::from_raw_os_error( 141 RtlNtStatusToDosError(/* Status= */ ret) as i32, 142 )); 143 }; 144 Ok(()) 145 } 146 } 147 148 use dll_notification_sys::*; 149 150 #[derive(Debug)] 151 pub struct DllNotificationData { 152 pub full_dll_name: OsString, 153 pub base_dll_name: OsString, 154 } 155 156 /// Callback context wrapper for DLL load notification functions. 157 /// 158 /// This struct provides a wrapper for invoking a function-like type any time a 159 /// DLL is loaded in the current process. This is done in a type-safe way, 160 /// provided that users of this struct observe some safety invariants. 161 /// 162 /// # Safety 163 /// The struct instance must not be used once it has been registered as a 164 /// notification target. The callback function assumes that it has a mutable 165 /// reference to the struct instance. Only once the callback is unregistered is 166 /// it safe to re-use the struct instance. 167 struct CallbackContext<F1, F2> 168 where 169 F1: FnMut(DllNotificationData), 170 F2: FnMut(DllNotificationData), 171 { 172 loaded_callback: F1, 173 unloaded_callback: F2, 174 } 175 176 impl<F1, F2> CallbackContext<F1, F2> 177 where 178 F1: FnMut(DllNotificationData), 179 F2: FnMut(DllNotificationData), 180 { 181 /// Create a new `CallbackContext` with the two callback functions. Takes 182 /// two callbacks, a `loaded_callback` which is called when a DLL is 183 /// loaded, and `unloaded_callback` which is called when a DLL is unloaded. new(loaded_callback: F1, unloaded_callback: F2) -> Self184 pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self { 185 CallbackContext { 186 loaded_callback, 187 unloaded_callback, 188 } 189 } 190 191 /// Provides a notification function that can be passed to the 192 /// `LdrRegisterDllNotification` function. get_notification_function(&self) -> LdrDllNotification193 pub fn get_notification_function(&self) -> LdrDllNotification { 194 Self::notification_function 195 } 196 197 /// A notification function with C linkage. This function assumes that it 198 /// has exclusive access to the instance of the struct passed through the 199 /// `context` parameter. notification_function( notification_reason: ULONG, notification_data: PLDR_DLL_NOTIFICATION_DATA, context: PVOID, )200 extern "C" fn notification_function( 201 notification_reason: ULONG, 202 notification_data: PLDR_DLL_NOTIFICATION_DATA, 203 context: PVOID, 204 ) { 205 let callback_context = 206 // SAFETY: The DLLWatcher guarantees that the CallbackContext instance is not null and 207 // that we have exclusive access to it. 208 unsafe { (context as *mut Self).as_mut() }.expect("context was null"); 209 210 assert!(!notification_data.is_null()); 211 212 match notification_reason { 213 LDR_DLL_NOTIFICATION_REASON_LOADED => { 214 // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the 215 // LDR_DLL_LOADED_NOTIFICATION_DATA because we got 216 // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification reason. 217 let loaded = unsafe { &mut (*notification_data).Loaded }; 218 219 assert!(!loaded.BaseDllName.is_null()); 220 221 // SAFETY: We assert that the pointer is not null and expect that the OS has 222 // provided a valid UNICODE_STRING struct. 223 let base_dll_name = unsafe { unicode_string_to_os_string(&*loaded.BaseDllName) }; 224 225 assert!(!loaded.FullDllName.is_null()); 226 227 // SAFETY: We assert that the pointer is not null and expect that the OS has 228 // provided a valid UNICODE_STRING struct. 229 let full_dll_name = unsafe { unicode_string_to_os_string(&*loaded.FullDllName) }; 230 231 (callback_context.loaded_callback)(DllNotificationData { 232 base_dll_name, 233 full_dll_name, 234 }); 235 } 236 LDR_DLL_NOTIFICATION_REASON_UNLOADED => { 237 // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the 238 // LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got 239 // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification reason. 240 let unloaded = unsafe { &mut (*notification_data).Unloaded }; 241 242 assert!(!unloaded.BaseDllName.is_null()); 243 244 // SAFETY: We assert that the pointer is not null and expect that the OS has 245 // provided a valid UNICODE_STRING struct. 246 let base_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.BaseDllName) }; 247 248 assert!(!unloaded.FullDllName.is_null()); 249 250 // SAFETY: We assert that the pointer is not null and expect that the OS has 251 // provided a valid UNICODE_STRING struct. 252 let full_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.FullDllName) }; 253 254 (callback_context.unloaded_callback)(DllNotificationData { 255 base_dll_name, 256 full_dll_name, 257 }) 258 } 259 n => panic!("invalid value \"{}\" for dll notification reason", n), 260 } 261 } 262 } 263 264 /// DLL watcher for monitoring DLL loads/unloads. 265 /// 266 /// Provides a method to invoke a function-like type any time a DLL 267 /// is loaded or unloaded in the current process. 268 pub struct DllWatcher<F1, F2> 269 where 270 F1: FnMut(DllNotificationData), 271 F2: FnMut(DllNotificationData), 272 { 273 context: Box<CallbackContext<F1, F2>>, 274 cookie: Option<ptr::NonNull<c_void>>, 275 } 276 277 impl<F1, F2> DllWatcher<F1, F2> 278 where 279 F1: FnMut(DllNotificationData), 280 F2: FnMut(DllNotificationData), 281 { 282 /// Create a new `DllWatcher` with the two callback functions. Takes two 283 /// callbacks, a `loaded_callback` which is called when a DLL is loaded, 284 /// and `unloaded_callback` which is called when a DLL is unloaded. new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self>285 pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self> { 286 let mut watcher = Self { 287 context: Box::new(CallbackContext::new(loaded_callback, unloaded_callback)), 288 cookie: None, 289 }; 290 let mut cookie: PVOID = ptr::null_mut(); 291 // SAFETY: We guarantee that the notification function that we register will have exclusive 292 // access to the context. 293 unsafe { 294 LdrRegisterDllNotification( 295 /* Flags= */ 0, 296 /* NotificationFunction= */ watcher.context.get_notification_function(), 297 /* Context= */ 298 &mut *watcher.context as *mut CallbackContext<F1, F2> as PVOID, 299 /* Cookie= */ &mut cookie as *mut PVOID, 300 )? 301 }; 302 watcher.cookie = ptr::NonNull::new(cookie); 303 Ok(watcher) 304 } 305 unregister_dll_notification(&mut self) -> io::Result<()>306 fn unregister_dll_notification(&mut self) -> io::Result<()> { 307 if let Some(c) = self.cookie.take() { 308 // SAFETY: We guarantee that `Cookie` was previously initialized. 309 unsafe { 310 LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)? 311 } 312 } 313 314 Ok(()) 315 } 316 } 317 318 impl<F1, F2> Drop for DllWatcher<F1, F2> 319 where 320 F1: FnMut(DllNotificationData), 321 F2: FnMut(DllNotificationData), 322 { drop(&mut self)323 fn drop(&mut self) { 324 self.unregister_dll_notification() 325 .expect("error unregistering dll notification"); 326 } 327 } 328 329 #[cfg(test)] 330 mod tests { 331 use std::collections::HashSet; 332 use std::ffi::CString; 333 use std::io; 334 335 use winapi::shared::minwindef::FALSE; 336 use winapi::shared::minwindef::TRUE; 337 use winapi::um::handleapi::CloseHandle; 338 use winapi::um::libloaderapi::FreeLibrary; 339 use winapi::um::libloaderapi::LoadLibraryA; 340 use winapi::um::synchapi::CreateEventA; 341 use winapi::um::synchapi::SetEvent; 342 use winapi::um::synchapi::WaitForSingleObject; 343 use winapi::um::winbase::WAIT_OBJECT_0; 344 345 use super::*; 346 347 // Arbitrarily chosen DLLs for load/unload test. Chosen because they're 348 // hopefully esoteric enough that they're probably not already loaded in 349 // the process so we can test load/unload notifications. 350 // 351 // Using a single DLL can lead to flakiness; since the tests are run in the 352 // same process, it can be hard to rely on the OS to clean up the DLL loaded 353 // by one test before the other test runs. Using a different DLL makes the 354 // tests more independent. 355 const TEST_DLL_NAME_1: &str = "Imagehlp.dll"; 356 const TEST_DLL_NAME_2: &str = "dbghelp.dll"; 357 358 #[test] load_dll()359 fn load_dll() { 360 let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString"); 361 let mut loaded_dlls: HashSet<OsString> = HashSet::new(); 362 let h_module = { 363 let _watcher = DllWatcher::new( 364 |data| { 365 loaded_dlls.insert(data.base_dll_name); 366 }, 367 |_data| (), 368 ) 369 .expect("failed to create DllWatcher"); 370 // SAFETY: We pass a valid C string in to the function. 371 unsafe { LoadLibraryA(test_dll_name.as_ptr()) } 372 }; 373 assert!( 374 !h_module.is_null(), 375 "failed to load {}: {}", 376 TEST_DLL_NAME_1, 377 io::Error::last_os_error() 378 ); 379 assert!( 380 !loaded_dlls.is_empty(), 381 "no DLL loads recorded by DLL watcher" 382 ); 383 assert!( 384 loaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_1.to_owned().into())), 385 "{} load wasn't recorded by DLL watcher", 386 TEST_DLL_NAME_1 387 ); 388 // SAFETY: We initialized h_module with a LoadLibraryA call. 389 let success = unsafe { FreeLibrary(h_module) } > 0; 390 assert!( 391 success, 392 "failed to free {}: {}", 393 TEST_DLL_NAME_1, 394 io::Error::last_os_error(), 395 ) 396 } 397 398 #[test] unload_dll()399 fn unload_dll() { 400 let mut unloaded_dlls: HashSet<OsString> = HashSet::new(); 401 let event = 402 // SAFETY: No pointers are passed. The handle may leak if the test fails. 403 unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) }; 404 assert!( 405 !event.is_null(), 406 "failed to create event; event was NULL: {}", 407 io::Error::last_os_error() 408 ); 409 { 410 let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString"); 411 let _watcher = DllWatcher::new( 412 |_data| (), 413 |data| { 414 unloaded_dlls.insert(data.base_dll_name); 415 // SAFETY: We assert that the event is valid above. 416 unsafe { SetEvent(event) }; 417 }, 418 ) 419 .expect("failed to create DllWatcher"); 420 // SAFETY: We pass a valid C string in to the function. 421 let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) }; 422 assert!( 423 !h_module.is_null(), 424 "failed to load {}: {}", 425 TEST_DLL_NAME_2, 426 io::Error::last_os_error() 427 ); 428 // SAFETY: We initialized h_module with a LoadLibraryA call. 429 let success = unsafe { FreeLibrary(h_module) } > 0; 430 assert!( 431 success, 432 "failed to free {}: {}", 433 TEST_DLL_NAME_2, 434 io::Error::last_os_error(), 435 ) 436 }; 437 // SAFETY: We assert that the event is valid above. 438 assert_eq!(unsafe { WaitForSingleObject(event, 5000) }, WAIT_OBJECT_0); 439 assert!( 440 !unloaded_dlls.is_empty(), 441 "no DLL unloads recorded by DLL watcher" 442 ); 443 assert!( 444 unloaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_2.to_owned().into())), 445 "{} unload wasn't recorded by DLL watcher", 446 TEST_DLL_NAME_2 447 ); 448 // SAFETY: We assert that the event is valid above. 449 unsafe { CloseHandle(event) }; 450 } 451 } 452