xref: /aosp_15_r20/external/crosvm/base/src/sys/windows/platform_timer_utils.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::io;
6 use std::mem::MaybeUninit;
7 use std::sync::Once;
8 use std::thread::sleep;
9 use std::time::Duration;
10 use std::time::Instant;
11 
12 use win_util::win32_string;
13 use win_util::win32_wide_string;
14 use winapi::shared::minwindef;
15 use winapi::shared::minwindef::HINSTANCE;
16 use winapi::shared::minwindef::HMODULE;
17 use winapi::shared::minwindef::PULONG;
18 use winapi::shared::ntdef::NTSTATUS;
19 use winapi::shared::ntdef::ULONG;
20 use winapi::shared::ntstatus::STATUS_SUCCESS;
21 use winapi::um::libloaderapi;
22 use winapi::um::mmsystem::TIMERR_NOERROR;
23 use winapi::um::timeapi::timeBeginPeriod;
24 use winapi::um::timeapi::timeEndPeriod;
25 use winapi::um::winnt::BOOLEAN;
26 
27 use crate::warn;
28 use crate::Error;
29 use crate::Result;
30 
31 static NT_INIT: Once = Once::new();
32 static mut NT_LIBRARY: MaybeUninit<HMODULE> = MaybeUninit::uninit();
33 
34 #[inline]
init_ntdll() -> Result<HINSTANCE>35 fn init_ntdll() -> Result<HINSTANCE> {
36     NT_INIT.call_once(|| {
37         // SAFETY: return value is checked.
38         unsafe {
39             *NT_LIBRARY.as_mut_ptr() =
40                 libloaderapi::LoadLibraryW(win32_wide_string("ntdll").as_ptr());
41 
42             if NT_LIBRARY.assume_init().is_null() {
43                 warn!("Failed to load ntdll: {}", Error::last());
44             }
45         };
46     });
47 
48     // SAFETY: NT_LIBRARY initialized above.
49     let handle = unsafe { NT_LIBRARY.assume_init() };
50     if handle.is_null() {
51         Err(Error::from(io::Error::new(
52             io::ErrorKind::NotFound,
53             "ntdll failed to load",
54         )))
55     } else {
56         Ok(handle)
57     }
58 }
59 
get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function>60 fn get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function> {
61     // SAFETY: return value is checked.
62     let symbol = unsafe { libloaderapi::GetProcAddress(handle, win32_string(proc_name).as_ptr()) };
63     if symbol.is_null() {
64         Err(Error::last())
65     } else {
66         Ok(symbol)
67     }
68 }
69 
70 /// Returns the resolution of timers on the host (current_res, max_res).
nt_query_timer_resolution() -> Result<(Duration, Duration)>71 pub fn nt_query_timer_resolution() -> Result<(Duration, Duration)> {
72     let handle = init_ntdll()?;
73 
74     // SAFETY: trivially safe
75     let func = unsafe {
76         std::mem::transmute::<
77             *mut minwindef::__some_function,
78             extern "system" fn(PULONG, PULONG, PULONG) -> NTSTATUS,
79         >(get_symbol(handle, "NtQueryTimerResolution")?)
80     };
81 
82     let mut min_res: u32 = 0;
83     let mut max_res: u32 = 0;
84     let mut current_res: u32 = 0;
85     let ret = func(
86         &mut min_res as *mut u32,
87         &mut max_res as *mut u32,
88         &mut current_res as *mut u32,
89     );
90 
91     if ret != STATUS_SUCCESS {
92         Err(Error::from(io::Error::new(
93             io::ErrorKind::Other,
94             "NtQueryTimerResolution failed",
95         )))
96     } else {
97         Ok((
98             Duration::from_nanos((current_res as u64) * 100),
99             Duration::from_nanos((max_res as u64) * 100),
100         ))
101     }
102 }
103 
nt_set_timer_resolution(resolution: Duration) -> Result<()>104 pub fn nt_set_timer_resolution(resolution: Duration) -> Result<()> {
105     let handle = init_ntdll()?;
106     // SAFETY: trivially safe
107     let func = unsafe {
108         std::mem::transmute::<
109             *mut minwindef::__some_function,
110             extern "system" fn(ULONG, BOOLEAN, PULONG) -> NTSTATUS,
111         >(get_symbol(handle, "NtSetTimerResolution")?)
112     };
113 
114     let requested_res: u32 = (resolution.as_nanos() / 100) as u32;
115     let mut current_res: u32 = 0;
116     let ret = func(
117         requested_res,
118         1, /* true */
119         &mut current_res as *mut u32,
120     );
121 
122     if ret != STATUS_SUCCESS {
123         Err(Error::from(io::Error::new(
124             io::ErrorKind::Other,
125             "NtSetTimerResolution failed",
126         )))
127     } else {
128         Ok(())
129     }
130 }
131 
/// Estimates the timer resolution empirically: times 100 consecutive 1ms sleeps and returns the
/// 90th smallest wall-clock duration (the 90th percentile).
pub fn measure_timer_resolution() -> Duration {
    let mut samples: Vec<Duration> = (0..100)
        .map(|_| {
            let began = Instant::now();
            // Windows cannot support sleeps shorter than 1ms.
            sleep(Duration::from_millis(1));
            began.elapsed()
        })
        .collect();

    samples.sort_unstable();
    samples[89]
}
145 
146 /// Note that Durations below 1ms are not supported and will panic.
set_time_period(res: Duration, begin: bool) -> Result<()>147 pub fn set_time_period(res: Duration, begin: bool) -> Result<()> {
148     if res.as_millis() < 1 {
149         panic!(
150             "time(Begin|End)Period does not support values below 1ms, but {:?} was requested.",
151             res
152         );
153     }
154     if res.as_millis() > u32::MAX as u128 {
155         panic!("time(Begin|End)Period does not support values above u32::MAX.",);
156     }
157 
158     let ret = if begin {
159         // SAFETY: Trivially safe. Note that the casts are safe because we know res is within u32's
160         // range.
161         unsafe { timeBeginPeriod(res.as_millis() as u32) }
162     } else {
163         // SAFETY: Trivially safe. Note that the casts are safe because we know res is within u32's
164         // range.
165         unsafe { timeEndPeriod(res.as_millis() as u32) }
166     };
167     if ret != TIMERR_NOERROR {
168         // These functions only have two return codes: NOERROR and NOCANDO.
169         Err(Error::from(io::Error::new(
170             io::ErrorKind::InvalidInput,
171             "timeBegin/EndPeriod failed",
172         )))
173     } else {
174         Ok(())
175     }
176 }
177 
/// These tests are `#[ignore]`d: they cannot run on Kokoro because of random slowness in that
/// environment.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless the measured resolution is at most 2ms.
    fn assert_res_within_bound(actual_res: Duration) {
        let upper_bound = Duration::from_millis(2);
        assert!(
            actual_res <= upper_bound,
            "actual_res was {:?}, expected <= 2ms",
            actual_res
        );
    }

    /// We're testing whether NtSetTimerResolution does what it says on the tin.
    #[test]
    #[ignore]
    fn setting_nt_timer_resolution_changes_resolution() {
        let (previous, _) = nt_query_timer_resolution().unwrap();

        nt_set_timer_resolution(Duration::from_millis(1)).unwrap();
        let measured = measure_timer_resolution();
        assert_res_within_bound(measured);
        nt_set_timer_resolution(previous).unwrap();
    }

    /// Same check as above, but through timeBeginPeriod/timeEndPeriod.
    #[test]
    #[ignore]
    fn setting_timer_resolution_changes_resolution() {
        let one_ms = Duration::from_millis(1);

        set_time_period(one_ms, true).unwrap();
        let measured = measure_timer_resolution();
        assert_res_within_bound(measured);
        set_time_period(one_ms, false).unwrap();
    }
}
212