xref: /aosp_15_r20/external/crosvm/win_util/src/dll_notification.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::ffi::c_void;
6 use std::ffi::OsString;
7 use std::io;
8 use std::ptr;
9 
10 use winapi::shared::minwindef::ULONG;
11 use winapi::um::winnt::PVOID;
12 
13 use super::unicode_string_to_os_string;
14 
15 // Required for Windows API FFI bindings, as the names of the FFI structs and
16 // functions get called out by the linter.
17 #[allow(non_upper_case_globals)]
18 #[allow(non_camel_case_types)]
19 #[allow(non_snake_case)]
20 #[allow(dead_code)]
21 mod dll_notification_sys {
22     use std::io;
23 
24     use winapi::shared::minwindef::ULONG;
25     use winapi::shared::ntdef::NTSTATUS;
26     use winapi::shared::ntdef::PCUNICODE_STRING;
27     use winapi::shared::ntstatus::STATUS_SUCCESS;
28     use winapi::um::libloaderapi::GetModuleHandleA;
29     use winapi::um::libloaderapi::GetProcAddress;
30     use winapi::um::winnt::CHAR;
31     use winapi::um::winnt::PVOID;
32 
33     #[repr(C)]
34     pub union _LDR_DLL_NOTIFICATION_DATA {
35         pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA,
36         pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA,
37     }
38     pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA;
39     pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA;
40 
41     #[repr(C)]
42     #[derive(Debug, Copy, Clone)]
43     pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA {
44         pub Flags: ULONG,                  // Reserved.
45         pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
46         pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
47         pub DllBase: PVOID,                // A pointer to the base address for the DLL in memory.
48         pub SizeOfImage: ULONG,            // The size of the DLL image, in bytes.
49     }
50     pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA;
51     pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA;
52 
53     #[repr(C)]
54     #[derive(Debug, Copy, Clone)]
55     pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA {
56         pub Flags: ULONG,                  // Reserved.
57         pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
58         pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
59         pub DllBase: PVOID,                // A pointer to the base address for the DLL in memory.
60         pub SizeOfImage: ULONG,            // The size of the DLL image, in bytes.
61     }
62     pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA;
63     pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA;
64 
65     pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1;
66     pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2;
67 
68     const NTDLL: &[u8] = b"ntdll\0";
69     const LDR_REGISTER_DLL_NOTIFICATION: &[u8] = b"LdrRegisterDllNotification\0";
70     const LDR_UNREGISTER_DLL_NOTIFICATION: &[u8] = b"LdrUnregisterDllNotification\0";
71 
72     pub type LdrDllNotification = unsafe extern "C" fn(
73         NotificationReason: ULONG,
74         NotificationData: PLDR_DLL_NOTIFICATION_DATA,
75         Context: PVOID,
76     );
77 
78     pub type FnLdrRegisterDllNotification =
79         unsafe extern "C" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS;
80     pub type FnLdrUnregisterDllNotification = unsafe extern "C" fn(PVOID) -> NTSTATUS;
81 
82     extern "C" {
RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG83         pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG;
84     }
85 
86     /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically
87     /// gets the address of the function and invokes the function with the given
88     /// arguments.
89     ///
90     /// # Safety
91     /// Unsafe as this function does not verify its arguments; the caller is
92     /// expected to verify the safety as if invoking the underlying C function.
LdrRegisterDllNotification( Flags: ULONG, NotificationFunction: LdrDllNotification, Context: PVOID, Cookie: *mut PVOID, ) -> io::Result<()>93     pub unsafe fn LdrRegisterDllNotification(
94         Flags: ULONG,
95         NotificationFunction: LdrDllNotification,
96         Context: PVOID,
97         Cookie: *mut PVOID,
98     ) -> io::Result<()> {
99         let proc_addr = GetProcAddress(
100             /* hModule= */
101             GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR),
102             /* lpProcName= */
103             LDR_REGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR,
104         );
105         if proc_addr.is_null() {
106             return Err(std::io::Error::last_os_error());
107         }
108         let ldr_register_dll_notification: FnLdrRegisterDllNotification =
109             std::mem::transmute(proc_addr);
110         let ret = ldr_register_dll_notification(Flags, NotificationFunction, Context, Cookie);
111         if ret != STATUS_SUCCESS {
112             return Err(io::Error::from_raw_os_error(
113                 RtlNtStatusToDosError(/* Status= */ ret) as i32,
114             ));
115         };
116         Ok(())
117     }
118 
119     /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically
120     /// gets the address of the function and invokes the function with the given
121     /// arguments.
122     ///
123     /// # Safety
124     /// Unsafe as this function does not verify its arguments; the caller is
125     /// expected to verify the safety as if invoking the underlying C function.
LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()>126     pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> {
127         let proc_addr = GetProcAddress(
128             /* hModule= */
129             GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR),
130             /* lpProcName= */
131             LDR_UNREGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR,
132         );
133         if proc_addr.is_null() {
134             return Err(std::io::Error::last_os_error());
135         }
136         let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification =
137             std::mem::transmute(proc_addr);
138         let ret = ldr_unregister_dll_notification(Cookie);
139         if ret != STATUS_SUCCESS {
140             return Err(io::Error::from_raw_os_error(
141                 RtlNtStatusToDosError(/* Status= */ ret) as i32,
142             ));
143         };
144         Ok(())
145     }
146 }
147 
148 use dll_notification_sys::*;
149 
/// Names of a DLL that was loaded into, or unloaded from, the current process.
///
/// Delivered to the callbacks registered with `DllWatcher`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DllNotificationData {
    /// The full path name of the DLL module.
    pub full_dll_name: OsString,
    /// The base file name of the DLL module.
    pub base_dll_name: OsString,
}
155 
156 /// Callback context wrapper for DLL load notification functions.
157 ///
158 /// This struct provides a wrapper for invoking a function-like type any time a
159 /// DLL is loaded in the current process. This is done in a type-safe way,
160 /// provided that users of this struct observe some safety invariants.
161 ///
162 /// # Safety
163 /// The struct instance must not be used once it has been registered as a
164 /// notification target. The callback function assumes that it has a mutable
165 /// reference to the struct instance. Only once the callback is unregistered is
166 /// it safe to re-use the struct instance.
167 struct CallbackContext<F1, F2>
168 where
169     F1: FnMut(DllNotificationData),
170     F2: FnMut(DllNotificationData),
171 {
172     loaded_callback: F1,
173     unloaded_callback: F2,
174 }
175 
176 impl<F1, F2> CallbackContext<F1, F2>
177 where
178     F1: FnMut(DllNotificationData),
179     F2: FnMut(DllNotificationData),
180 {
181     /// Create a new `CallbackContext` with the two callback functions. Takes
182     /// two callbacks, a `loaded_callback` which is called when a DLL is
183     /// loaded, and `unloaded_callback` which is called when a DLL is unloaded.
new(loaded_callback: F1, unloaded_callback: F2) -> Self184     pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self {
185         CallbackContext {
186             loaded_callback,
187             unloaded_callback,
188         }
189     }
190 
191     /// Provides a notification function that can be passed to the
192     /// `LdrRegisterDllNotification` function.
get_notification_function(&self) -> LdrDllNotification193     pub fn get_notification_function(&self) -> LdrDllNotification {
194         Self::notification_function
195     }
196 
197     /// A notification function with C linkage. This function assumes that it
198     /// has exclusive access to the instance of the struct passed through the
199     /// `context` parameter.
notification_function( notification_reason: ULONG, notification_data: PLDR_DLL_NOTIFICATION_DATA, context: PVOID, )200     extern "C" fn notification_function(
201         notification_reason: ULONG,
202         notification_data: PLDR_DLL_NOTIFICATION_DATA,
203         context: PVOID,
204     ) {
205         let callback_context =
206             // SAFETY: The DLLWatcher guarantees that the CallbackContext instance is not null and
207             // that we have exclusive access to it.
208             unsafe { (context as *mut Self).as_mut() }.expect("context was null");
209 
210         assert!(!notification_data.is_null());
211 
212         match notification_reason {
213             LDR_DLL_NOTIFICATION_REASON_LOADED => {
214                 // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the
215                 // LDR_DLL_LOADED_NOTIFICATION_DATA because we got
216                 // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification reason.
217                 let loaded = unsafe { &mut (*notification_data).Loaded };
218 
219                 assert!(!loaded.BaseDllName.is_null());
220 
221                 // SAFETY: We assert that the pointer is not null and expect that the OS has
222                 // provided a valid UNICODE_STRING struct.
223                 let base_dll_name = unsafe { unicode_string_to_os_string(&*loaded.BaseDllName) };
224 
225                 assert!(!loaded.FullDllName.is_null());
226 
227                 // SAFETY: We assert that the pointer is not null and expect that the OS has
228                 // provided a valid UNICODE_STRING struct.
229                 let full_dll_name = unsafe { unicode_string_to_os_string(&*loaded.FullDllName) };
230 
231                 (callback_context.loaded_callback)(DllNotificationData {
232                     base_dll_name,
233                     full_dll_name,
234                 });
235             }
236             LDR_DLL_NOTIFICATION_REASON_UNLOADED => {
237                 // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the
238                 // LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got
239                 // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification reason.
240                 let unloaded = unsafe { &mut (*notification_data).Unloaded };
241 
242                 assert!(!unloaded.BaseDllName.is_null());
243 
244                 // SAFETY: We assert that the pointer is not null and expect that the OS has
245                 // provided a valid UNICODE_STRING struct.
246                 let base_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.BaseDllName) };
247 
248                 assert!(!unloaded.FullDllName.is_null());
249 
250                 // SAFETY: We assert that the pointer is not null and expect that the OS has
251                 // provided a valid UNICODE_STRING struct.
252                 let full_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.FullDllName) };
253 
254                 (callback_context.unloaded_callback)(DllNotificationData {
255                     base_dll_name,
256                     full_dll_name,
257                 })
258             }
259             n => panic!("invalid value \"{}\" for dll notification reason", n),
260         }
261     }
262 }
263 
264 /// DLL watcher for monitoring DLL loads/unloads.
265 ///
266 /// Provides a method to invoke a function-like type any time a DLL
267 /// is loaded or unloaded in the current process.
268 pub struct DllWatcher<F1, F2>
269 where
270     F1: FnMut(DllNotificationData),
271     F2: FnMut(DllNotificationData),
272 {
273     context: Box<CallbackContext<F1, F2>>,
274     cookie: Option<ptr::NonNull<c_void>>,
275 }
276 
277 impl<F1, F2> DllWatcher<F1, F2>
278 where
279     F1: FnMut(DllNotificationData),
280     F2: FnMut(DllNotificationData),
281 {
282     /// Create a new `DllWatcher` with the two callback functions. Takes two
283     /// callbacks, a `loaded_callback` which is called when a DLL is loaded,
284     /// and `unloaded_callback` which is called when a DLL is unloaded.
new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self>285     pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self> {
286         let mut watcher = Self {
287             context: Box::new(CallbackContext::new(loaded_callback, unloaded_callback)),
288             cookie: None,
289         };
290         let mut cookie: PVOID = ptr::null_mut();
291         // SAFETY: We guarantee that the notification function that we register will have exclusive
292         // access to the context.
293         unsafe {
294             LdrRegisterDllNotification(
295                 /* Flags= */ 0,
296                 /* NotificationFunction= */ watcher.context.get_notification_function(),
297                 /* Context= */
298                 &mut *watcher.context as *mut CallbackContext<F1, F2> as PVOID,
299                 /* Cookie= */ &mut cookie as *mut PVOID,
300             )?
301         };
302         watcher.cookie = ptr::NonNull::new(cookie);
303         Ok(watcher)
304     }
305 
unregister_dll_notification(&mut self) -> io::Result<()>306     fn unregister_dll_notification(&mut self) -> io::Result<()> {
307         if let Some(c) = self.cookie.take() {
308             // SAFETY: We guarantee that `Cookie` was previously initialized.
309             unsafe {
310                 LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)?
311             }
312         }
313 
314         Ok(())
315     }
316 }
317 
318 impl<F1, F2> Drop for DllWatcher<F1, F2>
319 where
320     F1: FnMut(DllNotificationData),
321     F2: FnMut(DllNotificationData),
322 {
drop(&mut self)323     fn drop(&mut self) {
324         self.unregister_dll_notification()
325             .expect("error unregistering dll notification");
326     }
327 }
328 
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::ffi::CString;
    use std::io;

    use winapi::shared::minwindef::FALSE;
    use winapi::shared::minwindef::TRUE;
    use winapi::um::handleapi::CloseHandle;
    use winapi::um::libloaderapi::FreeLibrary;
    use winapi::um::libloaderapi::LoadLibraryA;
    use winapi::um::synchapi::CreateEventA;
    use winapi::um::synchapi::SetEvent;
    use winapi::um::synchapi::WaitForSingleObject;
    use winapi::um::winbase::WAIT_OBJECT_0;

    use super::*;

    // Arbitrarily chosen DLLs for load/unload test. Chosen because they're
    // hopefully esoteric enough that they're probably not already loaded in
    // the process so we can test load/unload notifications.
    //
    // Using a single DLL can lead to flakiness; since the tests are run in the
    // same process, it can be hard to rely on the OS to clean up the DLL loaded
    // by one test before the other test runs. Using a different DLL makes the
    // tests more independent.
    const TEST_DLL_NAME_1: &str = "Imagehlp.dll";
    const TEST_DLL_NAME_2: &str = "dbghelp.dll";

    #[test]
    fn load_dll() {
        let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString");
        let mut loaded_dlls: HashSet<OsString> = HashSet::new();
        let (h_module, load_error) = {
            let _watcher = DllWatcher::new(
                |data| {
                    loaded_dlls.insert(data.base_dll_name);
                },
                |_data| (),
            )
            .expect("failed to create DllWatcher");
            // SAFETY: We pass a valid C string in to the function.
            let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) };
            // Capture the error before `_watcher` is dropped: its Drop impl makes
            // further Win32 calls that may overwrite the thread's last-error value.
            (h_module, io::Error::last_os_error())
        };
        assert!(
            !h_module.is_null(),
            "failed to load {}: {}",
            TEST_DLL_NAME_1,
            load_error
        );
        assert!(
            !loaded_dlls.is_empty(),
            "no DLL loads recorded by DLL watcher"
        );
        assert!(
            loaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_1.to_owned().into())),
            "{} load wasn't recorded by DLL watcher",
            TEST_DLL_NAME_1
        );
        // SAFETY: We initialized h_module with a LoadLibraryA call.
        let success = unsafe { FreeLibrary(h_module) } != FALSE;
        assert!(
            success,
            "failed to free {}: {}",
            TEST_DLL_NAME_1,
            io::Error::last_os_error(),
        )
    }

    #[test]
    fn unload_dll() {
        let mut unloaded_dlls: HashSet<OsString> = HashSet::new();
        let event =
            // SAFETY: No pointers are passed. The handle may leak if the test fails.
            unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) };
        assert!(
            !event.is_null(),
            "failed to create event; event was NULL: {}",
            io::Error::last_os_error()
        );
        {
            let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString");
            let _watcher = DllWatcher::new(
                |_data| (),
                |data| {
                    unloaded_dlls.insert(data.base_dll_name);
                    // SAFETY: We assert that the event is valid above.
                    unsafe { SetEvent(event) };
                },
            )
            .expect("failed to create DllWatcher");
            // SAFETY: We pass a valid C string in to the function.
            let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) };
            assert!(
                !h_module.is_null(),
                "failed to load {}: {}",
                TEST_DLL_NAME_2,
                io::Error::last_os_error()
            );
            // SAFETY: We initialized h_module with a LoadLibraryA call.
            let success = unsafe { FreeLibrary(h_module) } != FALSE;
            assert!(
                success,
                "failed to free {}: {}",
                TEST_DLL_NAME_2,
                io::Error::last_os_error(),
            )
        };
        // SAFETY: We assert that the event is valid above.
        assert_eq!(unsafe { WaitForSingleObject(event, 5000) }, WAIT_OBJECT_0);
        assert!(
            !unloaded_dlls.is_empty(),
            "no DLL unloads recorded by DLL watcher"
        );
        assert!(
            unloaded_dlls.contains::<OsString>(&(TEST_DLL_NAME_2.to_owned().into())),
            "{} unload wasn't recorded by DLL watcher",
            TEST_DLL_NAME_2
        );
        // SAFETY: We assert that the event is valid above.
        unsafe { CloseHandle(event) };
    }
}
452