1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::io;
6 use std::mem::MaybeUninit;
7 use std::sync::Once;
8 use std::thread::sleep;
9 use std::time::Duration;
10 use std::time::Instant;
11
12 use win_util::win32_string;
13 use win_util::win32_wide_string;
14 use winapi::shared::minwindef;
15 use winapi::shared::minwindef::HINSTANCE;
16 use winapi::shared::minwindef::HMODULE;
17 use winapi::shared::minwindef::PULONG;
18 use winapi::shared::ntdef::NTSTATUS;
19 use winapi::shared::ntdef::ULONG;
20 use winapi::shared::ntstatus::STATUS_SUCCESS;
21 use winapi::um::libloaderapi;
22 use winapi::um::mmsystem::TIMERR_NOERROR;
23 use winapi::um::timeapi::timeBeginPeriod;
24 use winapi::um::timeapi::timeEndPeriod;
25 use winapi::um::winnt::BOOLEAN;
26
27 use crate::warn;
28 use crate::Error;
29 use crate::Result;
30
31 static NT_INIT: Once = Once::new();
32 static mut NT_LIBRARY: MaybeUninit<HMODULE> = MaybeUninit::uninit();
33
34 #[inline]
init_ntdll() -> Result<HINSTANCE>35 fn init_ntdll() -> Result<HINSTANCE> {
36 NT_INIT.call_once(|| {
37 // SAFETY: return value is checked.
38 unsafe {
39 *NT_LIBRARY.as_mut_ptr() =
40 libloaderapi::LoadLibraryW(win32_wide_string("ntdll").as_ptr());
41
42 if NT_LIBRARY.assume_init().is_null() {
43 warn!("Failed to load ntdll: {}", Error::last());
44 }
45 };
46 });
47
48 // SAFETY: NT_LIBRARY initialized above.
49 let handle = unsafe { NT_LIBRARY.assume_init() };
50 if handle.is_null() {
51 Err(Error::from(io::Error::new(
52 io::ErrorKind::NotFound,
53 "ntdll failed to load",
54 )))
55 } else {
56 Ok(handle)
57 }
58 }
59
get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function>60 fn get_symbol(handle: HMODULE, proc_name: &str) -> Result<*mut minwindef::__some_function> {
61 // SAFETY: return value is checked.
62 let symbol = unsafe { libloaderapi::GetProcAddress(handle, win32_string(proc_name).as_ptr()) };
63 if symbol.is_null() {
64 Err(Error::last())
65 } else {
66 Ok(symbol)
67 }
68 }
69
70 /// Returns the resolution of timers on the host (current_res, max_res).
nt_query_timer_resolution() -> Result<(Duration, Duration)>71 pub fn nt_query_timer_resolution() -> Result<(Duration, Duration)> {
72 let handle = init_ntdll()?;
73
74 // SAFETY: trivially safe
75 let func = unsafe {
76 std::mem::transmute::<
77 *mut minwindef::__some_function,
78 extern "system" fn(PULONG, PULONG, PULONG) -> NTSTATUS,
79 >(get_symbol(handle, "NtQueryTimerResolution")?)
80 };
81
82 let mut min_res: u32 = 0;
83 let mut max_res: u32 = 0;
84 let mut current_res: u32 = 0;
85 let ret = func(
86 &mut min_res as *mut u32,
87 &mut max_res as *mut u32,
88 &mut current_res as *mut u32,
89 );
90
91 if ret != STATUS_SUCCESS {
92 Err(Error::from(io::Error::new(
93 io::ErrorKind::Other,
94 "NtQueryTimerResolution failed",
95 )))
96 } else {
97 Ok((
98 Duration::from_nanos((current_res as u64) * 100),
99 Duration::from_nanos((max_res as u64) * 100),
100 ))
101 }
102 }
103
nt_set_timer_resolution(resolution: Duration) -> Result<()>104 pub fn nt_set_timer_resolution(resolution: Duration) -> Result<()> {
105 let handle = init_ntdll()?;
106 // SAFETY: trivially safe
107 let func = unsafe {
108 std::mem::transmute::<
109 *mut minwindef::__some_function,
110 extern "system" fn(ULONG, BOOLEAN, PULONG) -> NTSTATUS,
111 >(get_symbol(handle, "NtSetTimerResolution")?)
112 };
113
114 let requested_res: u32 = (resolution.as_nanos() / 100) as u32;
115 let mut current_res: u32 = 0;
116 let ret = func(
117 requested_res,
118 1, /* true */
119 &mut current_res as *mut u32,
120 );
121
122 if ret != STATUS_SUCCESS {
123 Err(Error::from(io::Error::new(
124 io::ErrorKind::Other,
125 "NtSetTimerResolution failed",
126 )))
127 } else {
128 Ok(())
129 }
130 }
131
/// Estimates the effective timer resolution empirically.
///
/// Times 100 requested 1ms sleeps and returns the 90th-percentile (the 90th smallest) observed
/// wall-clock duration.
pub fn measure_timer_resolution() -> Duration {
    const SAMPLES: usize = 100;
    // Zero-based index of the 90th smallest sample (89 for 100 samples).
    const P90_INDEX: usize = SAMPLES * 9 / 10 - 1;

    let mut samples: Vec<Duration> = (0..SAMPLES)
        .map(|_| {
            let started = Instant::now();
            // Windows cannot support sleeps shorter than 1ms.
            sleep(Duration::from_millis(1));
            started.elapsed()
        })
        .collect();

    // Equal `Duration`s are indistinguishable, so an unstable sort picks the same value.
    samples.sort_unstable();
    samples[P90_INDEX]
}
145
146 /// Note that Durations below 1ms are not supported and will panic.
set_time_period(res: Duration, begin: bool) -> Result<()>147 pub fn set_time_period(res: Duration, begin: bool) -> Result<()> {
148 if res.as_millis() < 1 {
149 panic!(
150 "time(Begin|End)Period does not support values below 1ms, but {:?} was requested.",
151 res
152 );
153 }
154 if res.as_millis() > u32::MAX as u128 {
155 panic!("time(Begin|End)Period does not support values above u32::MAX.",);
156 }
157
158 let ret = if begin {
159 // SAFETY: Trivially safe. Note that the casts are safe because we know res is within u32's
160 // range.
161 unsafe { timeBeginPeriod(res.as_millis() as u32) }
162 } else {
163 // SAFETY: Trivially safe. Note that the casts are safe because we know res is within u32's
164 // range.
165 unsafe { timeEndPeriod(res.as_millis() as u32) }
166 };
167 if ret != TIMERR_NOERROR {
168 // These functions only have two return codes: NOERROR and NOCANDO.
169 Err(Error::from(io::Error::new(
170 io::ErrorKind::InvalidInput,
171 "timeBegin/EndPeriod failed",
172 )))
173 } else {
174 Ok(())
175 }
176 }
177
/// Note that these tests cannot run on Kokoro due to random slowness in that environment. Both
/// are `#[ignore]`d, so they only run when requested explicitly (e.g. `cargo test -- --ignored`).
///
/// NOTE(review): both tests change timer settings that affect sleep granularity for (at least)
/// the whole test process. Running them concurrently may let one mask the other's failure, so
/// consider `--test-threads=1`.
#[cfg(test)]
mod tests {
    use super::*;

    /// We're testing whether NtSetTimerResolution does what it says on the tin.
    #[test]
    #[ignore]
    fn setting_nt_timer_resolution_changes_resolution() {
        let (old_res, _) = nt_query_timer_resolution().unwrap();

        nt_set_timer_resolution(Duration::from_millis(1)).unwrap();
        let measured = measure_timer_resolution();
        // Restore before asserting so that a failed assertion does not leave the modified
        // resolution in place for the rest of the test process.
        nt_set_timer_resolution(old_res).unwrap();
        assert_res_within_bound(measured);
    }

    /// Checks that timeBeginPeriod(1ms) yields ~1ms sleeps, then undoes it with timeEndPeriod.
    #[test]
    #[ignore]
    fn setting_timer_resolution_changes_resolution() {
        let res = Duration::from_millis(1);

        set_time_period(res, true).unwrap();
        let measured = measure_timer_resolution();
        // End the period before asserting so a failure does not leak the raised resolution.
        set_time_period(res, false).unwrap();
        assert_res_within_bound(measured);
    }

    /// Asserts that a measured resolution is at most 2ms.
    fn assert_res_within_bound(actual_res: Duration) {
        assert!(
            actual_res <= Duration::from_millis(2),
            "actual_res was {:?}, expected <= 2ms",
            actual_res
        );
    }
}
212