xref: /aosp_15_r20/external/crosvm/win_util/src/dll_notification.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1*bb4ee6a4SAndroid Build Coastguard Worker // Copyright 2022 The ChromiumOS Authors
2*bb4ee6a4SAndroid Build Coastguard Worker // Use of this source code is governed by a BSD-style license that can be
3*bb4ee6a4SAndroid Build Coastguard Worker // found in the LICENSE file.
4*bb4ee6a4SAndroid Build Coastguard Worker 
5*bb4ee6a4SAndroid Build Coastguard Worker use std::ffi::c_void;
6*bb4ee6a4SAndroid Build Coastguard Worker use std::ffi::OsString;
7*bb4ee6a4SAndroid Build Coastguard Worker use std::io;
8*bb4ee6a4SAndroid Build Coastguard Worker use std::ptr;
9*bb4ee6a4SAndroid Build Coastguard Worker 
10*bb4ee6a4SAndroid Build Coastguard Worker use winapi::shared::minwindef::ULONG;
11*bb4ee6a4SAndroid Build Coastguard Worker use winapi::um::winnt::PVOID;
12*bb4ee6a4SAndroid Build Coastguard Worker 
13*bb4ee6a4SAndroid Build Coastguard Worker use super::unicode_string_to_os_string;
14*bb4ee6a4SAndroid Build Coastguard Worker 
15*bb4ee6a4SAndroid Build Coastguard Worker // Required for Windows API FFI bindings, as the names of the FFI structs and
16*bb4ee6a4SAndroid Build Coastguard Worker // functions get called out by the linter.
17*bb4ee6a4SAndroid Build Coastguard Worker #[allow(non_upper_case_globals)]
18*bb4ee6a4SAndroid Build Coastguard Worker #[allow(non_camel_case_types)]
19*bb4ee6a4SAndroid Build Coastguard Worker #[allow(non_snake_case)]
20*bb4ee6a4SAndroid Build Coastguard Worker #[allow(dead_code)]
21*bb4ee6a4SAndroid Build Coastguard Worker mod dll_notification_sys {
22*bb4ee6a4SAndroid Build Coastguard Worker     use std::io;
23*bb4ee6a4SAndroid Build Coastguard Worker 
24*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::shared::minwindef::ULONG;
25*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::shared::ntdef::NTSTATUS;
26*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::shared::ntdef::PCUNICODE_STRING;
27*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::shared::ntstatus::STATUS_SUCCESS;
28*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::libloaderapi::GetModuleHandleA;
29*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::libloaderapi::GetProcAddress;
30*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::winnt::CHAR;
31*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::winnt::PVOID;
32*bb4ee6a4SAndroid Build Coastguard Worker 
33*bb4ee6a4SAndroid Build Coastguard Worker     #[repr(C)]
34*bb4ee6a4SAndroid Build Coastguard Worker     pub union _LDR_DLL_NOTIFICATION_DATA {
35*bb4ee6a4SAndroid Build Coastguard Worker         pub Loaded: LDR_DLL_LOADED_NOTIFICATION_DATA,
36*bb4ee6a4SAndroid Build Coastguard Worker         pub Unloaded: LDR_DLL_UNLOADED_NOTIFICATION_DATA,
37*bb4ee6a4SAndroid Build Coastguard Worker     }
38*bb4ee6a4SAndroid Build Coastguard Worker     pub type LDR_DLL_NOTIFICATION_DATA = _LDR_DLL_NOTIFICATION_DATA;
39*bb4ee6a4SAndroid Build Coastguard Worker     pub type PLDR_DLL_NOTIFICATION_DATA = *mut LDR_DLL_NOTIFICATION_DATA;
40*bb4ee6a4SAndroid Build Coastguard Worker 
41*bb4ee6a4SAndroid Build Coastguard Worker     #[repr(C)]
42*bb4ee6a4SAndroid Build Coastguard Worker     #[derive(Debug, Copy, Clone)]
43*bb4ee6a4SAndroid Build Coastguard Worker     pub struct _LDR_DLL_LOADED_NOTIFICATION_DATA {
44*bb4ee6a4SAndroid Build Coastguard Worker         pub Flags: ULONG,                  // Reserved.
45*bb4ee6a4SAndroid Build Coastguard Worker         pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
46*bb4ee6a4SAndroid Build Coastguard Worker         pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
47*bb4ee6a4SAndroid Build Coastguard Worker         pub DllBase: PVOID,                // A pointer to the base address for the DLL in memory.
48*bb4ee6a4SAndroid Build Coastguard Worker         pub SizeOfImage: ULONG,            // The size of the DLL image, in bytes.
49*bb4ee6a4SAndroid Build Coastguard Worker     }
50*bb4ee6a4SAndroid Build Coastguard Worker     pub type LDR_DLL_LOADED_NOTIFICATION_DATA = _LDR_DLL_LOADED_NOTIFICATION_DATA;
51*bb4ee6a4SAndroid Build Coastguard Worker     pub type PLDR_DLL_LOADED_NOTIFICATION_DATA = *mut LDR_DLL_LOADED_NOTIFICATION_DATA;
52*bb4ee6a4SAndroid Build Coastguard Worker 
53*bb4ee6a4SAndroid Build Coastguard Worker     #[repr(C)]
54*bb4ee6a4SAndroid Build Coastguard Worker     #[derive(Debug, Copy, Clone)]
55*bb4ee6a4SAndroid Build Coastguard Worker     pub struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA {
56*bb4ee6a4SAndroid Build Coastguard Worker         pub Flags: ULONG,                  // Reserved.
57*bb4ee6a4SAndroid Build Coastguard Worker         pub FullDllName: PCUNICODE_STRING, // The full path name of the DLL module.
58*bb4ee6a4SAndroid Build Coastguard Worker         pub BaseDllName: PCUNICODE_STRING, // The base file name of the DLL module.
59*bb4ee6a4SAndroid Build Coastguard Worker         pub DllBase: PVOID,                // A pointer to the base address for the DLL in memory.
60*bb4ee6a4SAndroid Build Coastguard Worker         pub SizeOfImage: ULONG,            // The size of the DLL image, in bytes.
61*bb4ee6a4SAndroid Build Coastguard Worker     }
62*bb4ee6a4SAndroid Build Coastguard Worker     pub type LDR_DLL_UNLOADED_NOTIFICATION_DATA = _LDR_DLL_UNLOADED_NOTIFICATION_DATA;
63*bb4ee6a4SAndroid Build Coastguard Worker     pub type PLDR_DLL_UNLOADED_NOTIFICATION_DATA = *mut LDR_DLL_UNLOADED_NOTIFICATION_DATA;
64*bb4ee6a4SAndroid Build Coastguard Worker 
65*bb4ee6a4SAndroid Build Coastguard Worker     pub const LDR_DLL_NOTIFICATION_REASON_LOADED: ULONG = 1;
66*bb4ee6a4SAndroid Build Coastguard Worker     pub const LDR_DLL_NOTIFICATION_REASON_UNLOADED: ULONG = 2;
67*bb4ee6a4SAndroid Build Coastguard Worker 
68*bb4ee6a4SAndroid Build Coastguard Worker     const NTDLL: &[u8] = b"ntdll\0";
69*bb4ee6a4SAndroid Build Coastguard Worker     const LDR_REGISTER_DLL_NOTIFICATION: &[u8] = b"LdrRegisterDllNotification\0";
70*bb4ee6a4SAndroid Build Coastguard Worker     const LDR_UNREGISTER_DLL_NOTIFICATION: &[u8] = b"LdrUnregisterDllNotification\0";
71*bb4ee6a4SAndroid Build Coastguard Worker 
72*bb4ee6a4SAndroid Build Coastguard Worker     pub type LdrDllNotification = unsafe extern "C" fn(
73*bb4ee6a4SAndroid Build Coastguard Worker         NotificationReason: ULONG,
74*bb4ee6a4SAndroid Build Coastguard Worker         NotificationData: PLDR_DLL_NOTIFICATION_DATA,
75*bb4ee6a4SAndroid Build Coastguard Worker         Context: PVOID,
76*bb4ee6a4SAndroid Build Coastguard Worker     );
77*bb4ee6a4SAndroid Build Coastguard Worker 
78*bb4ee6a4SAndroid Build Coastguard Worker     pub type FnLdrRegisterDllNotification =
79*bb4ee6a4SAndroid Build Coastguard Worker         unsafe extern "C" fn(ULONG, LdrDllNotification, PVOID, *mut PVOID) -> NTSTATUS;
80*bb4ee6a4SAndroid Build Coastguard Worker     pub type FnLdrUnregisterDllNotification = unsafe extern "C" fn(PVOID) -> NTSTATUS;
81*bb4ee6a4SAndroid Build Coastguard Worker 
82*bb4ee6a4SAndroid Build Coastguard Worker     extern "C" {
RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG83*bb4ee6a4SAndroid Build Coastguard Worker         pub fn RtlNtStatusToDosError(Status: NTSTATUS) -> ULONG;
84*bb4ee6a4SAndroid Build Coastguard Worker     }
85*bb4ee6a4SAndroid Build Coastguard Worker 
86*bb4ee6a4SAndroid Build Coastguard Worker     /// Wrapper for the NTDLL `LdrRegisterDllNotification` function. Dynamically
87*bb4ee6a4SAndroid Build Coastguard Worker     /// gets the address of the function and invokes the function with the given
88*bb4ee6a4SAndroid Build Coastguard Worker     /// arguments.
89*bb4ee6a4SAndroid Build Coastguard Worker     ///
90*bb4ee6a4SAndroid Build Coastguard Worker     /// # Safety
91*bb4ee6a4SAndroid Build Coastguard Worker     /// Unsafe as this function does not verify its arguments; the caller is
92*bb4ee6a4SAndroid Build Coastguard Worker     /// expected to verify the safety as if invoking the underlying C function.
LdrRegisterDllNotification( Flags: ULONG, NotificationFunction: LdrDllNotification, Context: PVOID, Cookie: *mut PVOID, ) -> io::Result<()>93*bb4ee6a4SAndroid Build Coastguard Worker     pub unsafe fn LdrRegisterDllNotification(
94*bb4ee6a4SAndroid Build Coastguard Worker         Flags: ULONG,
95*bb4ee6a4SAndroid Build Coastguard Worker         NotificationFunction: LdrDllNotification,
96*bb4ee6a4SAndroid Build Coastguard Worker         Context: PVOID,
97*bb4ee6a4SAndroid Build Coastguard Worker         Cookie: *mut PVOID,
98*bb4ee6a4SAndroid Build Coastguard Worker     ) -> io::Result<()> {
99*bb4ee6a4SAndroid Build Coastguard Worker         let proc_addr = GetProcAddress(
100*bb4ee6a4SAndroid Build Coastguard Worker             /* hModule= */
101*bb4ee6a4SAndroid Build Coastguard Worker             GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR),
102*bb4ee6a4SAndroid Build Coastguard Worker             /* lpProcName= */
103*bb4ee6a4SAndroid Build Coastguard Worker             LDR_REGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR,
104*bb4ee6a4SAndroid Build Coastguard Worker         );
105*bb4ee6a4SAndroid Build Coastguard Worker         if proc_addr.is_null() {
106*bb4ee6a4SAndroid Build Coastguard Worker             return Err(std::io::Error::last_os_error());
107*bb4ee6a4SAndroid Build Coastguard Worker         }
108*bb4ee6a4SAndroid Build Coastguard Worker         let ldr_register_dll_notification: FnLdrRegisterDllNotification =
109*bb4ee6a4SAndroid Build Coastguard Worker             std::mem::transmute(proc_addr);
110*bb4ee6a4SAndroid Build Coastguard Worker         let ret = ldr_register_dll_notification(Flags, NotificationFunction, Context, Cookie);
111*bb4ee6a4SAndroid Build Coastguard Worker         if ret != STATUS_SUCCESS {
112*bb4ee6a4SAndroid Build Coastguard Worker             return Err(io::Error::from_raw_os_error(
113*bb4ee6a4SAndroid Build Coastguard Worker                 RtlNtStatusToDosError(/* Status= */ ret) as i32,
114*bb4ee6a4SAndroid Build Coastguard Worker             ));
115*bb4ee6a4SAndroid Build Coastguard Worker         };
116*bb4ee6a4SAndroid Build Coastguard Worker         Ok(())
117*bb4ee6a4SAndroid Build Coastguard Worker     }
118*bb4ee6a4SAndroid Build Coastguard Worker 
119*bb4ee6a4SAndroid Build Coastguard Worker     /// Wrapper for the NTDLL `LdrUnregisterDllNotification` function. Dynamically
120*bb4ee6a4SAndroid Build Coastguard Worker     /// gets the address of the function and invokes the function with the given
121*bb4ee6a4SAndroid Build Coastguard Worker     /// arguments.
122*bb4ee6a4SAndroid Build Coastguard Worker     ///
123*bb4ee6a4SAndroid Build Coastguard Worker     /// # Safety
124*bb4ee6a4SAndroid Build Coastguard Worker     /// Unsafe as this function does not verify its arguments; the caller is
125*bb4ee6a4SAndroid Build Coastguard Worker     /// expected to verify the safety as if invoking the underlying C function.
LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()>126*bb4ee6a4SAndroid Build Coastguard Worker     pub unsafe fn LdrUnregisterDllNotification(Cookie: PVOID) -> io::Result<()> {
127*bb4ee6a4SAndroid Build Coastguard Worker         let proc_addr = GetProcAddress(
128*bb4ee6a4SAndroid Build Coastguard Worker             /* hModule= */
129*bb4ee6a4SAndroid Build Coastguard Worker             GetModuleHandleA(/* lpModuleName= */ NTDLL.as_ptr() as *const CHAR),
130*bb4ee6a4SAndroid Build Coastguard Worker             /* lpProcName= */
131*bb4ee6a4SAndroid Build Coastguard Worker             LDR_UNREGISTER_DLL_NOTIFICATION.as_ptr() as *const CHAR,
132*bb4ee6a4SAndroid Build Coastguard Worker         );
133*bb4ee6a4SAndroid Build Coastguard Worker         if proc_addr.is_null() {
134*bb4ee6a4SAndroid Build Coastguard Worker             return Err(std::io::Error::last_os_error());
135*bb4ee6a4SAndroid Build Coastguard Worker         }
136*bb4ee6a4SAndroid Build Coastguard Worker         let ldr_unregister_dll_notification: FnLdrUnregisterDllNotification =
137*bb4ee6a4SAndroid Build Coastguard Worker             std::mem::transmute(proc_addr);
138*bb4ee6a4SAndroid Build Coastguard Worker         let ret = ldr_unregister_dll_notification(Cookie);
139*bb4ee6a4SAndroid Build Coastguard Worker         if ret != STATUS_SUCCESS {
140*bb4ee6a4SAndroid Build Coastguard Worker             return Err(io::Error::from_raw_os_error(
141*bb4ee6a4SAndroid Build Coastguard Worker                 RtlNtStatusToDosError(/* Status= */ ret) as i32,
142*bb4ee6a4SAndroid Build Coastguard Worker             ));
143*bb4ee6a4SAndroid Build Coastguard Worker         };
144*bb4ee6a4SAndroid Build Coastguard Worker         Ok(())
145*bb4ee6a4SAndroid Build Coastguard Worker     }
146*bb4ee6a4SAndroid Build Coastguard Worker }
147*bb4ee6a4SAndroid Build Coastguard Worker 
148*bb4ee6a4SAndroid Build Coastguard Worker use dll_notification_sys::*;
149*bb4ee6a4SAndroid Build Coastguard Worker 
/// Names of a DLL that was loaded into, or unloaded from, the current process.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DllNotificationData {
    /// The full path name of the DLL module.
    pub full_dll_name: OsString,
    /// The base file name of the DLL module.
    pub base_dll_name: OsString,
}
155*bb4ee6a4SAndroid Build Coastguard Worker 
156*bb4ee6a4SAndroid Build Coastguard Worker /// Callback context wrapper for DLL load notification functions.
157*bb4ee6a4SAndroid Build Coastguard Worker ///
158*bb4ee6a4SAndroid Build Coastguard Worker /// This struct provides a wrapper for invoking a function-like type any time a
159*bb4ee6a4SAndroid Build Coastguard Worker /// DLL is loaded in the current process. This is done in a type-safe way,
160*bb4ee6a4SAndroid Build Coastguard Worker /// provided that users of this struct observe some safety invariants.
161*bb4ee6a4SAndroid Build Coastguard Worker ///
162*bb4ee6a4SAndroid Build Coastguard Worker /// # Safety
163*bb4ee6a4SAndroid Build Coastguard Worker /// The struct instance must not be used once it has been registered as a
164*bb4ee6a4SAndroid Build Coastguard Worker /// notification target. The callback function assumes that it has a mutable
165*bb4ee6a4SAndroid Build Coastguard Worker /// reference to the struct instance. Only once the callback is unregistered is
166*bb4ee6a4SAndroid Build Coastguard Worker /// it safe to re-use the struct instance.
167*bb4ee6a4SAndroid Build Coastguard Worker struct CallbackContext<F1, F2>
168*bb4ee6a4SAndroid Build Coastguard Worker where
169*bb4ee6a4SAndroid Build Coastguard Worker     F1: FnMut(DllNotificationData),
170*bb4ee6a4SAndroid Build Coastguard Worker     F2: FnMut(DllNotificationData),
171*bb4ee6a4SAndroid Build Coastguard Worker {
172*bb4ee6a4SAndroid Build Coastguard Worker     loaded_callback: F1,
173*bb4ee6a4SAndroid Build Coastguard Worker     unloaded_callback: F2,
174*bb4ee6a4SAndroid Build Coastguard Worker }
175*bb4ee6a4SAndroid Build Coastguard Worker 
176*bb4ee6a4SAndroid Build Coastguard Worker impl<F1, F2> CallbackContext<F1, F2>
177*bb4ee6a4SAndroid Build Coastguard Worker where
178*bb4ee6a4SAndroid Build Coastguard Worker     F1: FnMut(DllNotificationData),
179*bb4ee6a4SAndroid Build Coastguard Worker     F2: FnMut(DllNotificationData),
180*bb4ee6a4SAndroid Build Coastguard Worker {
181*bb4ee6a4SAndroid Build Coastguard Worker     /// Create a new `CallbackContext` with the two callback functions. Takes
182*bb4ee6a4SAndroid Build Coastguard Worker     /// two callbacks, a `loaded_callback` which is called when a DLL is
183*bb4ee6a4SAndroid Build Coastguard Worker     /// loaded, and `unloaded_callback` which is called when a DLL is unloaded.
new(loaded_callback: F1, unloaded_callback: F2) -> Self184*bb4ee6a4SAndroid Build Coastguard Worker     pub fn new(loaded_callback: F1, unloaded_callback: F2) -> Self {
185*bb4ee6a4SAndroid Build Coastguard Worker         CallbackContext {
186*bb4ee6a4SAndroid Build Coastguard Worker             loaded_callback,
187*bb4ee6a4SAndroid Build Coastguard Worker             unloaded_callback,
188*bb4ee6a4SAndroid Build Coastguard Worker         }
189*bb4ee6a4SAndroid Build Coastguard Worker     }
190*bb4ee6a4SAndroid Build Coastguard Worker 
191*bb4ee6a4SAndroid Build Coastguard Worker     /// Provides a notification function that can be passed to the
192*bb4ee6a4SAndroid Build Coastguard Worker     /// `LdrRegisterDllNotification` function.
get_notification_function(&self) -> LdrDllNotification193*bb4ee6a4SAndroid Build Coastguard Worker     pub fn get_notification_function(&self) -> LdrDllNotification {
194*bb4ee6a4SAndroid Build Coastguard Worker         Self::notification_function
195*bb4ee6a4SAndroid Build Coastguard Worker     }
196*bb4ee6a4SAndroid Build Coastguard Worker 
197*bb4ee6a4SAndroid Build Coastguard Worker     /// A notification function with C linkage. This function assumes that it
198*bb4ee6a4SAndroid Build Coastguard Worker     /// has exclusive access to the instance of the struct passed through the
199*bb4ee6a4SAndroid Build Coastguard Worker     /// `context` parameter.
notification_function( notification_reason: ULONG, notification_data: PLDR_DLL_NOTIFICATION_DATA, context: PVOID, )200*bb4ee6a4SAndroid Build Coastguard Worker     extern "C" fn notification_function(
201*bb4ee6a4SAndroid Build Coastguard Worker         notification_reason: ULONG,
202*bb4ee6a4SAndroid Build Coastguard Worker         notification_data: PLDR_DLL_NOTIFICATION_DATA,
203*bb4ee6a4SAndroid Build Coastguard Worker         context: PVOID,
204*bb4ee6a4SAndroid Build Coastguard Worker     ) {
205*bb4ee6a4SAndroid Build Coastguard Worker         let callback_context =
206*bb4ee6a4SAndroid Build Coastguard Worker             // SAFETY: The DLLWatcher guarantees that the CallbackContext instance is not null and
207*bb4ee6a4SAndroid Build Coastguard Worker             // that we have exclusive access to it.
208*bb4ee6a4SAndroid Build Coastguard Worker             unsafe { (context as *mut Self).as_mut() }.expect("context was null");
209*bb4ee6a4SAndroid Build Coastguard Worker 
210*bb4ee6a4SAndroid Build Coastguard Worker         assert!(!notification_data.is_null());
211*bb4ee6a4SAndroid Build Coastguard Worker 
212*bb4ee6a4SAndroid Build Coastguard Worker         match notification_reason {
213*bb4ee6a4SAndroid Build Coastguard Worker             LDR_DLL_NOTIFICATION_REASON_LOADED => {
214*bb4ee6a4SAndroid Build Coastguard Worker                 // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the
215*bb4ee6a4SAndroid Build Coastguard Worker                 // LDR_DLL_LOADED_NOTIFICATION_DATA because we got
216*bb4ee6a4SAndroid Build Coastguard Worker                 // LDR_DLL_NOTIFICATION_REASON_LOADED as the notification reason.
217*bb4ee6a4SAndroid Build Coastguard Worker                 let loaded = unsafe { &mut (*notification_data).Loaded };
218*bb4ee6a4SAndroid Build Coastguard Worker 
219*bb4ee6a4SAndroid Build Coastguard Worker                 assert!(!loaded.BaseDllName.is_null());
220*bb4ee6a4SAndroid Build Coastguard Worker 
221*bb4ee6a4SAndroid Build Coastguard Worker                 // SAFETY: We assert that the pointer is not null and expect that the OS has
222*bb4ee6a4SAndroid Build Coastguard Worker                 // provided a valid UNICODE_STRING struct.
223*bb4ee6a4SAndroid Build Coastguard Worker                 let base_dll_name = unsafe { unicode_string_to_os_string(&*loaded.BaseDllName) };
224*bb4ee6a4SAndroid Build Coastguard Worker 
225*bb4ee6a4SAndroid Build Coastguard Worker                 assert!(!loaded.FullDllName.is_null());
226*bb4ee6a4SAndroid Build Coastguard Worker 
227*bb4ee6a4SAndroid Build Coastguard Worker                 // SAFETY: We assert that the pointer is not null and expect that the OS has
228*bb4ee6a4SAndroid Build Coastguard Worker                 // provided a valid UNICODE_STRING struct.
229*bb4ee6a4SAndroid Build Coastguard Worker                 let full_dll_name = unsafe { unicode_string_to_os_string(&*loaded.FullDllName) };
230*bb4ee6a4SAndroid Build Coastguard Worker 
231*bb4ee6a4SAndroid Build Coastguard Worker                 (callback_context.loaded_callback)(DllNotificationData {
232*bb4ee6a4SAndroid Build Coastguard Worker                     base_dll_name,
233*bb4ee6a4SAndroid Build Coastguard Worker                     full_dll_name,
234*bb4ee6a4SAndroid Build Coastguard Worker                 });
235*bb4ee6a4SAndroid Build Coastguard Worker             }
236*bb4ee6a4SAndroid Build Coastguard Worker             LDR_DLL_NOTIFICATION_REASON_UNLOADED => {
237*bb4ee6a4SAndroid Build Coastguard Worker                 // SAFETY: We know that the LDR_DLL_NOTIFICATION_DATA union contains the
238*bb4ee6a4SAndroid Build Coastguard Worker                 // LDR_DLL_UNLOADED_NOTIFICATION_DATA because we got
239*bb4ee6a4SAndroid Build Coastguard Worker                 // LDR_DLL_NOTIFICATION_REASON_UNLOADED as the notification reason.
240*bb4ee6a4SAndroid Build Coastguard Worker                 let unloaded = unsafe { &mut (*notification_data).Unloaded };
241*bb4ee6a4SAndroid Build Coastguard Worker 
242*bb4ee6a4SAndroid Build Coastguard Worker                 assert!(!unloaded.BaseDllName.is_null());
243*bb4ee6a4SAndroid Build Coastguard Worker 
244*bb4ee6a4SAndroid Build Coastguard Worker                 // SAFETY: We assert that the pointer is not null and expect that the OS has
245*bb4ee6a4SAndroid Build Coastguard Worker                 // provided a valid UNICODE_STRING struct.
246*bb4ee6a4SAndroid Build Coastguard Worker                 let base_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.BaseDllName) };
247*bb4ee6a4SAndroid Build Coastguard Worker 
248*bb4ee6a4SAndroid Build Coastguard Worker                 assert!(!unloaded.FullDllName.is_null());
249*bb4ee6a4SAndroid Build Coastguard Worker 
250*bb4ee6a4SAndroid Build Coastguard Worker                 // SAFETY: We assert that the pointer is not null and expect that the OS has
251*bb4ee6a4SAndroid Build Coastguard Worker                 // provided a valid UNICODE_STRING struct.
252*bb4ee6a4SAndroid Build Coastguard Worker                 let full_dll_name = unsafe { unicode_string_to_os_string(&*unloaded.FullDllName) };
253*bb4ee6a4SAndroid Build Coastguard Worker 
254*bb4ee6a4SAndroid Build Coastguard Worker                 (callback_context.unloaded_callback)(DllNotificationData {
255*bb4ee6a4SAndroid Build Coastguard Worker                     base_dll_name,
256*bb4ee6a4SAndroid Build Coastguard Worker                     full_dll_name,
257*bb4ee6a4SAndroid Build Coastguard Worker                 })
258*bb4ee6a4SAndroid Build Coastguard Worker             }
259*bb4ee6a4SAndroid Build Coastguard Worker             n => panic!("invalid value \"{}\" for dll notification reason", n),
260*bb4ee6a4SAndroid Build Coastguard Worker         }
261*bb4ee6a4SAndroid Build Coastguard Worker     }
262*bb4ee6a4SAndroid Build Coastguard Worker }
263*bb4ee6a4SAndroid Build Coastguard Worker 
264*bb4ee6a4SAndroid Build Coastguard Worker /// DLL watcher for monitoring DLL loads/unloads.
265*bb4ee6a4SAndroid Build Coastguard Worker ///
266*bb4ee6a4SAndroid Build Coastguard Worker /// Provides a method to invoke a function-like type any time a DLL
267*bb4ee6a4SAndroid Build Coastguard Worker /// is loaded or unloaded in the current process.
268*bb4ee6a4SAndroid Build Coastguard Worker pub struct DllWatcher<F1, F2>
269*bb4ee6a4SAndroid Build Coastguard Worker where
270*bb4ee6a4SAndroid Build Coastguard Worker     F1: FnMut(DllNotificationData),
271*bb4ee6a4SAndroid Build Coastguard Worker     F2: FnMut(DllNotificationData),
272*bb4ee6a4SAndroid Build Coastguard Worker {
273*bb4ee6a4SAndroid Build Coastguard Worker     context: Box<CallbackContext<F1, F2>>,
274*bb4ee6a4SAndroid Build Coastguard Worker     cookie: Option<ptr::NonNull<c_void>>,
275*bb4ee6a4SAndroid Build Coastguard Worker }
276*bb4ee6a4SAndroid Build Coastguard Worker 
277*bb4ee6a4SAndroid Build Coastguard Worker impl<F1, F2> DllWatcher<F1, F2>
278*bb4ee6a4SAndroid Build Coastguard Worker where
279*bb4ee6a4SAndroid Build Coastguard Worker     F1: FnMut(DllNotificationData),
280*bb4ee6a4SAndroid Build Coastguard Worker     F2: FnMut(DllNotificationData),
281*bb4ee6a4SAndroid Build Coastguard Worker {
282*bb4ee6a4SAndroid Build Coastguard Worker     /// Create a new `DllWatcher` with the two callback functions. Takes two
283*bb4ee6a4SAndroid Build Coastguard Worker     /// callbacks, a `loaded_callback` which is called when a DLL is loaded,
284*bb4ee6a4SAndroid Build Coastguard Worker     /// and `unloaded_callback` which is called when a DLL is unloaded.
new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self>285*bb4ee6a4SAndroid Build Coastguard Worker     pub fn new(loaded_callback: F1, unloaded_callback: F2) -> io::Result<Self> {
286*bb4ee6a4SAndroid Build Coastguard Worker         let mut watcher = Self {
287*bb4ee6a4SAndroid Build Coastguard Worker             context: Box::new(CallbackContext::new(loaded_callback, unloaded_callback)),
288*bb4ee6a4SAndroid Build Coastguard Worker             cookie: None,
289*bb4ee6a4SAndroid Build Coastguard Worker         };
290*bb4ee6a4SAndroid Build Coastguard Worker         let mut cookie: PVOID = ptr::null_mut();
291*bb4ee6a4SAndroid Build Coastguard Worker         // SAFETY: We guarantee that the notification function that we register will have exclusive
292*bb4ee6a4SAndroid Build Coastguard Worker         // access to the context.
293*bb4ee6a4SAndroid Build Coastguard Worker         unsafe {
294*bb4ee6a4SAndroid Build Coastguard Worker             LdrRegisterDllNotification(
295*bb4ee6a4SAndroid Build Coastguard Worker                 /* Flags= */ 0,
296*bb4ee6a4SAndroid Build Coastguard Worker                 /* NotificationFunction= */ watcher.context.get_notification_function(),
297*bb4ee6a4SAndroid Build Coastguard Worker                 /* Context= */
298*bb4ee6a4SAndroid Build Coastguard Worker                 &mut *watcher.context as *mut CallbackContext<F1, F2> as PVOID,
299*bb4ee6a4SAndroid Build Coastguard Worker                 /* Cookie= */ &mut cookie as *mut PVOID,
300*bb4ee6a4SAndroid Build Coastguard Worker             )?
301*bb4ee6a4SAndroid Build Coastguard Worker         };
302*bb4ee6a4SAndroid Build Coastguard Worker         watcher.cookie = ptr::NonNull::new(cookie);
303*bb4ee6a4SAndroid Build Coastguard Worker         Ok(watcher)
304*bb4ee6a4SAndroid Build Coastguard Worker     }
305*bb4ee6a4SAndroid Build Coastguard Worker 
unregister_dll_notification(&mut self) -> io::Result<()>306*bb4ee6a4SAndroid Build Coastguard Worker     fn unregister_dll_notification(&mut self) -> io::Result<()> {
307*bb4ee6a4SAndroid Build Coastguard Worker         if let Some(c) = self.cookie.take() {
308*bb4ee6a4SAndroid Build Coastguard Worker             // SAFETY: We guarantee that `Cookie` was previously initialized.
309*bb4ee6a4SAndroid Build Coastguard Worker             unsafe {
310*bb4ee6a4SAndroid Build Coastguard Worker                 LdrUnregisterDllNotification(/* Cookie= */ c.as_ptr() as PVOID)?
311*bb4ee6a4SAndroid Build Coastguard Worker             }
312*bb4ee6a4SAndroid Build Coastguard Worker         }
313*bb4ee6a4SAndroid Build Coastguard Worker 
314*bb4ee6a4SAndroid Build Coastguard Worker         Ok(())
315*bb4ee6a4SAndroid Build Coastguard Worker     }
316*bb4ee6a4SAndroid Build Coastguard Worker }
317*bb4ee6a4SAndroid Build Coastguard Worker 
318*bb4ee6a4SAndroid Build Coastguard Worker impl<F1, F2> Drop for DllWatcher<F1, F2>
319*bb4ee6a4SAndroid Build Coastguard Worker where
320*bb4ee6a4SAndroid Build Coastguard Worker     F1: FnMut(DllNotificationData),
321*bb4ee6a4SAndroid Build Coastguard Worker     F2: FnMut(DllNotificationData),
322*bb4ee6a4SAndroid Build Coastguard Worker {
drop(&mut self)323*bb4ee6a4SAndroid Build Coastguard Worker     fn drop(&mut self) {
324*bb4ee6a4SAndroid Build Coastguard Worker         self.unregister_dll_notification()
325*bb4ee6a4SAndroid Build Coastguard Worker             .expect("error unregistering dll notification");
326*bb4ee6a4SAndroid Build Coastguard Worker     }
327*bb4ee6a4SAndroid Build Coastguard Worker }
328*bb4ee6a4SAndroid Build Coastguard Worker 
329*bb4ee6a4SAndroid Build Coastguard Worker #[cfg(test)]
330*bb4ee6a4SAndroid Build Coastguard Worker mod tests {
331*bb4ee6a4SAndroid Build Coastguard Worker     use std::collections::HashSet;
332*bb4ee6a4SAndroid Build Coastguard Worker     use std::ffi::CString;
333*bb4ee6a4SAndroid Build Coastguard Worker     use std::io;
334*bb4ee6a4SAndroid Build Coastguard Worker 
335*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::shared::minwindef::FALSE;
336*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::shared::minwindef::TRUE;
337*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::handleapi::CloseHandle;
338*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::libloaderapi::FreeLibrary;
339*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::libloaderapi::LoadLibraryA;
340*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::synchapi::CreateEventA;
341*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::synchapi::SetEvent;
342*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::synchapi::WaitForSingleObject;
343*bb4ee6a4SAndroid Build Coastguard Worker     use winapi::um::winbase::WAIT_OBJECT_0;
344*bb4ee6a4SAndroid Build Coastguard Worker 
345*bb4ee6a4SAndroid Build Coastguard Worker     use super::*;
346*bb4ee6a4SAndroid Build Coastguard Worker 
347*bb4ee6a4SAndroid Build Coastguard Worker     // Arbitrarily chosen DLLs for load/unload test. Chosen because they're
348*bb4ee6a4SAndroid Build Coastguard Worker     // hopefully esoteric enough that they're probably not already loaded in
349*bb4ee6a4SAndroid Build Coastguard Worker     // the process so we can test load/unload notifications.
350*bb4ee6a4SAndroid Build Coastguard Worker     //
351*bb4ee6a4SAndroid Build Coastguard Worker     // Using a single DLL can lead to flakiness; since the tests are run in the
352*bb4ee6a4SAndroid Build Coastguard Worker     // same process, it can be hard to rely on the OS to clean up the DLL loaded
353*bb4ee6a4SAndroid Build Coastguard Worker     // by one test before the other test runs. Using a different DLL makes the
354*bb4ee6a4SAndroid Build Coastguard Worker     // tests more independent.
355*bb4ee6a4SAndroid Build Coastguard Worker     const TEST_DLL_NAME_1: &str = "Imagehlp.dll";
356*bb4ee6a4SAndroid Build Coastguard Worker     const TEST_DLL_NAME_2: &str = "dbghelp.dll";
357*bb4ee6a4SAndroid Build Coastguard Worker 
/// Checks that `DllWatcher` invokes its load callback with the DLL's base name for a DLL
/// loaded while the watcher is registered.
#[test]
fn load_dll() {
    let test_dll_name = CString::new(TEST_DLL_NAME_1).expect("failed to create CString");
    let mut loaded_dlls: HashSet<OsString> = HashSet::new();
    let h_module = {
        // The watcher lives only inside this scope. It is dropped, ending the closure's
        // mutable borrow of `loaded_dlls`, before the recorded names are inspected below.
        let _watcher = DllWatcher::new(
            |data| {
                loaded_dlls.insert(data.base_dll_name);
            },
            |_data| (),
        )
        .expect("failed to create DllWatcher");
        // SAFETY: We pass a valid C string in to the function.
        unsafe { LoadLibraryA(test_dll_name.as_ptr()) }
    };
    assert!(
        !h_module.is_null(),
        "failed to load {}: {}",
        TEST_DLL_NAME_1,
        io::Error::last_os_error()
    );

    // Free the module before checking what the watcher recorded. Otherwise a failing
    // assertion below would leave the DLL loaded in the process that all tests share.
    // FreeLibrary returns a BOOL, where any nonzero value means success.
    // SAFETY: h_module is the non-null handle returned by the LoadLibraryA call above, and
    // it is freed exactly once.
    let freed = unsafe { FreeLibrary(h_module) } != 0;
    assert!(
        freed,
        "failed to free {}: {}",
        TEST_DLL_NAME_1,
        io::Error::last_os_error(),
    );

    assert!(
        !loaded_dlls.is_empty(),
        "no DLL loads recorded by DLL watcher"
    );
    // DLL file names on Windows are normally case-insensitive. Match ignoring ASCII case
    // instead of depending on the reported base name having exactly the capitalization
    // passed to LoadLibraryA.
    assert!(
        loaded_dlls
            .iter()
            .any(|name| name.eq_ignore_ascii_case(TEST_DLL_NAME_1)),
        "{} load wasn't recorded by DLL watcher (recorded: {:?})",
        TEST_DLL_NAME_1,
        loaded_dlls
    );
}
397*bb4ee6a4SAndroid Build Coastguard Worker 
/// Checks that `DllWatcher` invokes its unload callback with the DLL's base name for a DLL
/// freed while the watcher is registered.
#[test]
fn unload_dll() {
    /// Runs the wrapped closure when dropped. The test uses it to close the event handle
    /// even when an assertion panics.
    struct OnDrop<F: FnMut()>(F);

    impl<F: FnMut()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            (self.0)();
        }
    }

    let mut unloaded_dlls: HashSet<OsString> = HashSet::new();
    let event =
        // SAFETY: No pointers are passed. A non-null handle is closed by `_close_event`
        // below, even if the test fails.
        unsafe { CreateEventA(std::ptr::null_mut(), TRUE, FALSE, std::ptr::null_mut()) };
    assert!(
        !event.is_null(),
        "failed to create event; event was NULL: {}",
        io::Error::last_os_error()
    );
    // This guard is declared before the watcher scope below, so during normal exit and
    // during unwinding it is dropped after the watcher. The unload callback therefore can
    // never signal a closed handle.
    let _close_event = OnDrop(move || {
        // SAFETY: `event` was checked to be non-null above, and this is the only place it
        // is closed.
        unsafe { CloseHandle(event) };
    });

    {
        let test_dll_name = CString::new(TEST_DLL_NAME_2).expect("failed to create CString");
        let _watcher = DllWatcher::new(
            |_data| (),
            |data| {
                unloaded_dlls.insert(data.base_dll_name);
                // SAFETY: `event` was checked to be non-null above, and it stays open until
                // after this watcher has been dropped.
                unsafe { SetEvent(event) };
            },
        )
        .expect("failed to create DllWatcher");
        // SAFETY: We pass a valid C string in to the function.
        let h_module = unsafe { LoadLibraryA(test_dll_name.as_ptr()) };
        assert!(
            !h_module.is_null(),
            "failed to load {}: {}",
            TEST_DLL_NAME_2,
            io::Error::last_os_error()
        );
        // FreeLibrary returns a BOOL, where any nonzero value means success.
        // SAFETY: h_module is the non-null handle returned by the LoadLibraryA call above,
        // and it is freed exactly once.
        let freed = unsafe { FreeLibrary(h_module) } != 0;
        assert!(
            freed,
            "failed to free {}: {}",
            TEST_DLL_NAME_2,
            io::Error::last_os_error(),
        );
    }

    // SAFETY: `event` was checked to be non-null above and is not closed until
    // `_close_event` is dropped at the end of this function.
    let wait_result = unsafe { WaitForSingleObject(event, 5000) };
    assert_eq!(
        wait_result, WAIT_OBJECT_0,
        "the DLL unload callback did not signal the event"
    );
    assert!(
        !unloaded_dlls.is_empty(),
        "no DLL unloads recorded by DLL watcher"
    );
    // DLL file names on Windows are normally case-insensitive. Match ignoring ASCII case
    // instead of depending on the reported base name's exact capitalization.
    assert!(
        unloaded_dlls
            .iter()
            .any(|name| name.eq_ignore_ascii_case(TEST_DLL_NAME_2)),
        "{} unload wasn't recorded by DLL watcher (recorded: {:?})",
        TEST_DLL_NAME_2,
        unloaded_dlls
    );
}
451*bb4ee6a4SAndroid Build Coastguard Worker }
452