// Copyright 2022 The ChromiumOS Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. use std::ffi::c_void; use std::ffi::OsString; use std::io; use std::ptr; use winapi::shared::minwindef::ULONG; use winapi::um::winnt::PVOID; use super::unicode_string_to_os_string; // Required for Windows API FFI bindings, as the names of the FFI structs and // functions get called out by the linter. #[allow(non_upper_case_globals)] #[allow(non_camel_case_types)] #[allow(non_snake_case)] #[allow(dead_code)] mod dll_notification_sys { use std::io; use winapi::shared::minwindef::ULONG; use winapi::shared::ntdef::NTSTATUS; use winapi::shared::ntdef::PCUNICODE_STRING; use winapi::shared::ntstatus::STATUS_SUCCESS; use winapi::um::libloaderapi::GetModuleHandleA; use winapi::um::libloaderapi::GetProcAddress; use winapi::um::winnt::CHAR; use winapi::um::winnt::PVOID; #[repr(C)] pub union _LDR_DLL_NOTIFICATION_DATA { pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA, pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA, } pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA; pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA { pub Flags: ULONG, // Reserved. pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module. pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module. pub DllBase: PVOID, // A pointer to the base address for the DLL in memory. pub SizeOfImage: ULONG, // The size of the DLL image, in bytes. } pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA; pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA { pub Flags: ULONG, // Reserved. pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module. pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module. pub DllBase: PVOID, // A pointer to the base address for the DLL in memory. pub SizeOfImage: ULONG, // The size of the DLL image, in bytes. } pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA; pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA; pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1; pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2; const NTDLL: &[u8] = b"ntdll\0"; const LDR_REGISTER_DLL_NOTIFICATION: &[u8] = b"LdrRegisterDllNotification\0"; const LDR_UNREGISTER_DLL_NOTIFICATION: &[u8] = b"LdrUnregisterDllNotification\0"; pub type LdrDllNotification = unsafe extern "C" fn( NotificationReason: ULONG, NotificationData: PLDR_DLL_NOTIFICATION_DATA, Context: PVOID, ); pub type FnLdrRegisterDllNotification = unsafe extern "C" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS; pub type FnLdrUnregisterDllNotification = unsafe extern "C" fn(PVOID) -> NTSTATUS; extern "C" { pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG; } /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically /// gets the address of the function and invokes the function with the given /// arguments. /// /// # Safety /// Unsafe as this function does not verify its arguments; the caller is /// expected to verify the safety as if invoking the underlying C function. pub unsafe fn LdrRegisterDllNotification( Flags: ULONG, NotificationFunction: LdrDllNotification, Context: PVOID, Cookie: *mut PVOID, ) -> io::Result<()> { let proc_addr = GetProcAddress( /* hModule= */ GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR), /* lpProcName= */ LDR_REGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR, ); if proc_addr.is_null() { return Err(std::io::Error::last_os_error()); } let ldr_register_dll_notification: FnLdrRegisterDllNotification = std::mem::transmute(proc_addr); let ret = ldr_register_dll_notification(Flags, NotificationFunction, Context, Cookie); if ret != STATUS_SUCCESS { return Err(io::Error::from_raw_os_error( RtlNtStatusToDosError(/* Status= */ ret) as i32, )); }; Ok(()) } /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically /// gets the address of the function and invokes the function with the given /// arguments. /// /// # Safety /// Unsafe as this function does not verify its arguments; the caller is /// expected to verify the safety as if invoking the underlying C function. pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> { let proc_addr = GetProcAddress( /* hModule= */ GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR), /* lpProcName= */ LDR_UNREGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR, ); if proc_addr.is_null() { return Err(std::io::Error::last_os_error()); } let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification = std::mem::transmute(proc_addr); let ret = ldr_unregister_dll_notification(Cookie); if ret != STATUS_SUCCESS { return Err(io::Error::from_raw_os_error( RtlNtStatusToDosError(/* Status= */ ret) as i32, )); }; Ok(()) } } use dll_notification_sys::*; #[derive(Debug)] pub struct DllNotificationData { pub full_dll_name: OsString, pub base_dll_name: OsString, } /// Callback context wrapper for DLL load notification functions. /// /// This struct provides a wrapper for invoking a function-like type any time a /// DLL is loaded in the current process. This is done in a type-safe way, /// provided that users of this struct observe some safety invariants. /// /// # Safety /// The struct instance must not be used once it has been registered as a /// notification target. The callback function assumes that it has a mutable /// reference to the struct instance. Only once the callback is unregistered is /// it safe to re-use the struct instance. struct CallbackContext where F1: FnMut(DllNotificationData), F2: FnMut(DllNotificationData), { loaded_callback: F1, unloaded_callback: F2, } impl CallbackContext where F1: FnMut(DllNotificationData), F2: FnMut(DllNotificationData), { /// Create a new `CallbackContext` with the two callback functions. Takes /// two callbacks, a `loaded_callback` which is called when a DLL is /// loaded, and `unloaded_callback` which is called when a DLL is unloaded. pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self { CallbackContext { loaded_callback, unloaded_callback, } } /// Provides a notification function that can be passed to the /// `LdrRegisterDllNotification` function. pub fn get_notification_function(&self) -> LdrDllNotification { Self::notification_function } /// A notification function with C linkage. This function assumes that it /// has exclusive access to the instance of the struct passed through the /// `context` parameter. extern "C" fn notification_function( notification_reason: ULONG, notification_data: PLDR_DLL_NOTIFICATION_DATA, context: PVOID, ) { let callback_context = // SAFETY: The DLLWatcher guarantees that the CallbackContext instance is not null and // that we have exclusive access to it. unsafe { (context as *mut Self).as_mut() }.expect("context was null"); assert!(!notification_data.is_null()); match notification_reason { LDR_DLL_NOTIFICATION_REASON_LOADED => { // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the // LDR_DLL_LOADED_NOTIFICATION_DATA because we got // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification reason. let loaded = unsafe { &mut (*notification_data).Loaded }; assert!(!loaded.BaseDllName.is_null()); // SAFETY: We assert that the pointer is not null and expect that the OS has // provided a valid UNICODE_STRING struct. let base_dll_name = unsafe { unicode_string_to_os_string(&*loaded.BaseDllName) }; assert!(!loaded.FullDllName.is_null()); // SAFETY: We assert that the pointer is not null and expect that the OS has // provided a valid UNICODE_STRING struct. let full_dll_name = unsafe { unicode_string_to_os_string(&*loaded.FullDllName) }; (callback_context.loaded_callback)(DllNotificationData { base_dll_name, full_dll_name, }); } LDR_DLL_NOTIFICATION_REASON_UNLOADED => { // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the // LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification reason. let unloaded = unsafe { &mut (*notification_data).Unloaded }; assert!(!unloaded.BaseDllName.is_null()); // SAFETY: We assert that the pointer is not null and expect that the OS has // provided a valid UNICODE_STRING struct. let base_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.BaseDllName) }; assert!(!unloaded.FullDllName.is_null()); // SAFETY: We assert that the pointer is not null and expect that the OS has // provided a valid UNICODE_STRING struct. let full_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.FullDllName) }; (callback_context.unloaded_callback)(DllNotificationData { base_dll_name, full_dll_name, }) } n => panic!("invalid value \"{}\" for dll notification reason", n), } } } /// DLL watcher for monitoring DLL loads/unloads. /// /// Provides a method to invoke a function-like type any time a DLL /// is loaded or unloaded in the current process. pub struct DllWatcher where F1: FnMut(DllNotificationData), F2: FnMut(DllNotificationData), { context: Box>, cookie: Option>, } impl DllWatcher where F1: FnMut(DllNotificationData), F2: FnMut(DllNotificationData), { /// Create a new `DllWatcher` with the two callback functions. Takes two /// callbacks, a `loaded_callback` which is called when a DLL is loaded, /// and `unloaded_callback` which is called when a DLL is unloaded. pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result { let mut watcher = Self { context: Box::new(CallbackContext::new(loaded_callback, unloaded_callback)), cookie: None, }; let mut cookie: PVOID = ptr::null_mut(); // SAFETY: We guarantee that the notification function that we register will have exclusive // access to the context. unsafe { LdrRegisterDllNotification( /* Flags= */ 0, /* NotificationFunction= */ watcher.context.get_notification_function(), /* Context= */ &mut *watcher.context as *mut CallbackContext as PVOID, /* Cookie= */ &mut cookie as *mut PVOID, )? }; watcher.cookie = ptr::NonNull::new(cookie); Ok(watcher) } fn unregister_dll_notification(&mut self) -> io::Result<()> { if let Some(c) = self.cookie.take() { // SAFETY: We guarantee that `Cookie` was previously initialized. unsafe { LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)? } } Ok(()) } } impl Drop for DllWatcher where F1: FnMut(DllNotificationData), F2: FnMut(DllNotificationData), { fn drop(&mut self) { self.unregister_dll_notification() .expect("error unregistering dll notification"); } } #[cfg(test)] mod tests { use std::collections::HashSet; use std::ffi::CString; use std::io; use winapi::shared::minwindef::FALSE; use winapi::shared::minwindef::TRUE; use winapi::um::handleapi::CloseHandle; use winapi::um::libloaderapi::FreeLibrary; use winapi::um::libloaderapi::LoadLibraryA; use winapi::um::synchapi::CreateEventA; use winapi::um::synchapi::SetEvent; use winapi::um::synchapi::WaitForSingleObject; use winapi::um::winbase::WAIT_OBJECT_0; use super::*; // Arbitrarily chosen DLLs for load/unload test. Chosen because they're // hopefully esoteric enough that they're probably not already loaded in // the process so we can test load/unload notifications. // // Using a single DLL can lead to flakiness; since the tests are run in the // same process, it can be hard to rely on the OS to clean up the DLL loaded // by one test before the other test runs. Using a different DLL makes the // tests more independent. const TEST_DLL_NAME_1: &str = "Imagehlp.dll"; const TEST_DLL_NAME_2: &str = "dbghelp.dll"; #[test] fn load_dll() { let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString"); let mut loaded_dlls: HashSet = HashSet::new(); let h_module = { let _watcher = DllWatcher::new( |data| { loaded_dlls.insert(data.base_dll_name); }, |_data| (), ) .expect("failed to create DllWatcher"); // SAFETY: We pass a valid C string in to the function. unsafe { LoadLibraryA(test_dll_name.as_ptr()) } }; assert!( !h_module.is_null(), "failed to load {}: {}", TEST_DLL_NAME_1, io::Error::last_os_error() ); assert!( !loaded_dlls.is_empty(), "no DLL loads recorded by DLL watcher" ); assert!( loaded_dlls.contains::(&(TEST_DLL_NAME_1.to_owned().into())), "{} load wasn't recorded by DLL watcher", TEST_DLL_NAME_1 ); // SAFETY: We initialized h_module with a LoadLibraryA call. let success = unsafe { FreeLibrary(h_module) } > 0; assert!( success, "failed to free {}: {}", TEST_DLL_NAME_1, io::Error::last_os_error(), ) } #[test] fn unload_dll() { let mut unloaded_dlls: HashSet = HashSet::new(); let event = // SAFETY: No pointers are passed. The handle may leak if the test fails. unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) }; assert!( !event.is_null(), "failed to create event; event was NULL: {}", io::Error::last_os_error() ); { let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString"); let _watcher = DllWatcher::new( |_data| (), |data| { unloaded_dlls.insert(data.base_dll_name); // SAFETY: We assert that the event is valid above. unsafe { SetEvent(event) }; }, ) .expect("failed to create DllWatcher"); // SAFETY: We pass a valid C string in to the function. let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) }; assert!( !h_module.is_null(), "failed to load {}: {}", TEST_DLL_NAME_2, io::Error::last_os_error() ); // SAFETY: We initialized h_module with a LoadLibraryA call. let success = unsafe { FreeLibrary(h_module) } > 0; assert!( success, "failed to free {}: {}", TEST_DLL_NAME_2, io::Error::last_os_error(), ) }; // SAFETY: We assert that the event is valid above. assert_eq!(unsafe { WaitForSingleObject(event, 5000) }, WAIT_OBJECT_0); assert!( !unloaded_dlls.is_empty(), "no DLL unloads recorded by DLL watcher" ); assert!( unloaded_dlls.contains::(&(TEST_DLL_NAME_2.to_owned().into())), "{} unload wasn't recorded by DLL watcher", TEST_DLL_NAME_2 ); // SAFETY: We assert that the event is valid above. unsafe { CloseHandle(event) }; } }