xref: /aosp_15_r20/external/crosvm/base/src/sys/windows/read_write_wrappers.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 use std::io;
6 
7 use win_util::fail_if_zero;
8 use winapi::shared::minwindef::DWORD;
9 use winapi::shared::minwindef::LPCVOID;
10 use winapi::shared::minwindef::LPVOID;
11 use winapi::shared::minwindef::TRUE;
12 use winapi::shared::winerror::ERROR_IO_PENDING;
13 use winapi::um::fileapi::ReadFile;
14 use winapi::um::fileapi::WriteFile;
15 use winapi::um::ioapiset::GetOverlappedResult;
16 use winapi::um::minwinbase::OVERLAPPED;
17 
18 use crate::AsRawDescriptor;
19 use crate::Event;
20 use crate::RawDescriptor;
21 
22 /// # Safety
23 /// 1. buf points to memory that will not be freed until the write operation completes.
24 /// 2. buf points to at least buf_len bytes.
write_file( handle: &dyn AsRawDescriptor, buf: *const u8, buf_len: usize, overlapped: Option<&mut OVERLAPPED>, ) -> io::Result<usize>25 pub unsafe fn write_file(
26     handle: &dyn AsRawDescriptor,
27     buf: *const u8,
28     buf_len: usize,
29     overlapped: Option<&mut OVERLAPPED>,
30 ) -> io::Result<usize> {
31     let is_overlapped = overlapped.is_some();
32 
33     // Safe because buf points to a valid region of memory whose size we have computed,
34     // pipe has not been closed (as it's managed by this object), and we check the return
35     // value for any errors
36     let mut bytes_written: DWORD = 0;
37     let success_flag = WriteFile(
38         handle.as_raw_descriptor(),
39         buf as LPCVOID,
40         buf_len
41             .try_into()
42             .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?,
43         match overlapped {
44             Some(_) => std::ptr::null_mut(),
45             None => &mut bytes_written,
46         },
47         match overlapped {
48             Some(v) => v,
49             None => std::ptr::null_mut(),
50         },
51     );
52 
53     if success_flag == 0 {
54         let err = io::Error::last_os_error();
55         if Some(ERROR_IO_PENDING as i32) == err.raw_os_error() && is_overlapped {
56             Ok(0)
57         } else {
58             Err(err)
59         }
60     } else {
61         Ok(bytes_written as usize)
62     }
63 }
64 
65 /// # Safety
66 /// 1. buf points to memory that will not be freed until the read operation completes.
67 /// 2. buf points to at least buf_len bytes.
read_file( handle: &dyn AsRawDescriptor, buf: *mut u8, buf_len: usize, overlapped: Option<&mut OVERLAPPED>, ) -> io::Result<usize>68 pub unsafe fn read_file(
69     handle: &dyn AsRawDescriptor,
70     buf: *mut u8,
71     buf_len: usize,
72     overlapped: Option<&mut OVERLAPPED>,
73 ) -> io::Result<usize> {
74     // Used to verify if ERROR_IO_PENDING should be an error.
75     let is_overlapped = overlapped.is_some();
76 
77     // Safe because we cap the size of the read to the size of the buffer
78     // and check the return code
79     let mut bytes_read: DWORD = 0;
80     let success_flag = ReadFile(
81         handle.as_raw_descriptor(),
82         buf as LPVOID,
83         buf_len as DWORD,
84         match overlapped {
85             Some(_) => std::ptr::null_mut(),
86             None => &mut bytes_read,
87         },
88         match overlapped {
89             Some(v) => v,
90             None => std::ptr::null_mut(),
91         },
92     );
93 
94     if success_flag == 0 {
95         let e = io::Error::last_os_error();
96         match e.raw_os_error() {
97             // ERROR_IO_PENDING, according the to docs, isn't really an error. This just means
98             // that the ReadFile operation hasn't completed. In this case,
99             // `get_overlapped_result` will wait until the operation is completed.
100             Some(error_code) if error_code == ERROR_IO_PENDING as i32 && is_overlapped => Ok(0),
101             _ => Err(e),
102         }
103     } else {
104         Ok(bytes_read as usize)
105     }
106 }
107 
set_overlapped_offset(overlapped: &mut OVERLAPPED, offset: u64)108 fn set_overlapped_offset(overlapped: &mut OVERLAPPED, offset: u64) {
109     // # Safety: Safe because overlapped is allocated, and we are manipulating non-overlapping
110     //           fields.
111     unsafe {
112         overlapped.u.s_mut().Offset = (offset & 0xffffffff) as DWORD;
113         overlapped.u.s_mut().OffsetHigh = (offset >> 32) as DWORD;
114     }
115 }
116 
117 // Creates a new `OVERLAPPED` struct with given, if any, offset and event
create_overlapped(offset: Option<u64>, event: Option<RawDescriptor>) -> OVERLAPPED118 pub fn create_overlapped(offset: Option<u64>, event: Option<RawDescriptor>) -> OVERLAPPED {
119     let mut overlapped = OVERLAPPED::default();
120     if let Some(offset) = offset {
121         set_overlapped_offset(&mut overlapped, offset);
122     }
123     if let Some(event) = event {
124         overlapped.hEvent = event;
125     }
126     overlapped
127 }
128 
129 /// Reads buf from given handle from offset in a blocking mode.
130 /// handle is expected to be opened in overlapped mode.
read_overlapped_blocking( handle: &dyn AsRawDescriptor, offset: u64, buf: &mut [u8], ) -> io::Result<usize>131 pub fn read_overlapped_blocking(
132     handle: &dyn AsRawDescriptor,
133     offset: u64,
134     buf: &mut [u8],
135 ) -> io::Result<usize> {
136     let mut size_transferred = 0;
137     let event = Event::new()?;
138     let mut overlapped = create_overlapped(Some(offset), Some(event.as_raw_descriptor()));
139 
140     // Safety: Safe because we check return values after the calls.
141     unsafe {
142         let _ = read_file(handle, buf.as_mut_ptr(), buf.len(), Some(&mut overlapped))?;
143         fail_if_zero!(GetOverlappedResult(
144             handle.as_raw_descriptor(),
145             &mut overlapped,
146             &mut size_transferred,
147             TRUE,
148         ));
149     }
150     Ok(size_transferred as usize)
151 }
152 
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::fs::OpenOptions;
    use std::os::windows::fs::OpenOptionsExt;
    use std::path::Path;
    use std::path::PathBuf;

    use tempfile::TempDir;
    use winapi::um::winbase::FILE_FLAG_OVERLAPPED;

    use super::*;

    /// Returns a path named `test` inside a fresh temporary directory, plus the
    /// `TempDir` guard. Callers keep the guard bound for as long as they use the path.
    fn tempfile_path() -> (PathBuf, TempDir) {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test");
        (file_path, dir)
    }

    /// Opens an existing file for reading and writing with `FILE_FLAG_OVERLAPPED`.
    fn open_overlapped(path: &Path) -> File {
        OpenOptions::new()
            .read(true)
            .write(true)
            .custom_flags(FILE_FLAG_OVERLAPPED)
            .open(path)
            .unwrap()
    }

    #[test]
    fn test_read_overlapped() {
        let (file_path, _tmpdir) = tempfile_path();
        let data: [u8; 6] = [0, 1, 2, 3, 5, 6];
        std::fs::write(&file_path, data).unwrap();

        let of = open_overlapped(&file_path);
        let mut buf: [u8; 3] = [0; 3];
        read_overlapped_blocking(&of, 3, &mut buf).unwrap();
        assert_eq!(buf, data[3..6]);
    }

    #[test]
    fn test_create_overlapped_splits_offset() {
        let overlapped = create_overlapped(Some(0x1234_5678_9abc_def0), None);
        // SAFETY: create_overlapped wrote the offset fields of the `s` union member.
        let offsets = unsafe { overlapped.u.s() };
        assert_eq!(offsets.Offset, 0x9abc_def0);
        assert_eq!(offsets.OffsetHigh, 0x1234_5678);
    }

    #[test]
    fn test_create_overlapped_sets_event() {
        let event = Event::new().unwrap();
        let overlapped = create_overlapped(None, Some(event.as_raw_descriptor()));
        assert_eq!(overlapped.hEvent, event.as_raw_descriptor());
    }

    #[test]
    fn test_write_then_read_overlapped() {
        let (file_path, _tmpdir) = tempfile_path();
        std::fs::write(&file_path, [0u8; 4]).unwrap();
        let of = open_overlapped(&file_path);

        let data: [u8; 3] = [7, 8, 9];
        let event = Event::new().unwrap();
        let mut overlapped = create_overlapped(Some(1), Some(event.as_raw_descriptor()));
        let mut written: DWORD = 0;
        // SAFETY: `data` and `overlapped` outlive the write because GetOverlappedResult
        // waits (bWait = TRUE) for it to complete before returning.
        unsafe {
            write_file(&of, data.as_ptr(), data.len(), Some(&mut overlapped)).unwrap();
            assert_ne!(
                GetOverlappedResult(of.as_raw_descriptor(), &mut overlapped, &mut written, TRUE),
                0
            );
        }
        assert_eq!(written as usize, data.len());

        let mut buf = [0u8; 4];
        assert_eq!(read_overlapped_blocking(&of, 0, &mut buf).unwrap(), 4);
        assert_eq!(buf, [0, 7, 8, 9]);
    }
}
192