xref: /aosp_15_r20/external/crosvm/devices/src/virtio/iommu/sys/linux.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2022 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 pub mod vfio_wrapper;
6 
7 use std::cell::RefCell;
8 use std::collections::BTreeMap;
9 use std::fs::File;
10 use std::rc::Rc;
11 use std::sync::Arc;
12 
13 use base::error;
14 use base::MemoryMappingBuilder;
15 use base::TubeError;
16 use cros_async::AsyncTube;
17 use cros_async::Executor;
18 use sync::Mutex;
19 use vm_control::VirtioIOMMURequest;
20 use vm_control::VirtioIOMMUResponse;
21 use vm_control::VirtioIOMMUVfioCommand;
22 use vm_control::VirtioIOMMUVfioResult;
23 use vm_control::VmMemoryRegionId;
24 
25 use self::vfio_wrapper::VfioWrapper;
26 use crate::virtio::iommu::ipc_memory_mapper::IommuRequest;
27 use crate::virtio::iommu::ipc_memory_mapper::IommuResponse;
28 use crate::virtio::iommu::DmabufRegionEntry;
29 use crate::virtio::iommu::Result;
30 use crate::virtio::iommu::State;
31 use crate::virtio::IommuError;
32 use crate::VfioContainer;
33 
34 impl State {
handle_add_vfio_device( &mut self, endpoint_addr: u32, wrapper: VfioWrapper, ) -> VirtioIOMMUVfioResult35     pub(in crate::virtio::iommu) fn handle_add_vfio_device(
36         &mut self,
37         endpoint_addr: u32,
38         wrapper: VfioWrapper,
39     ) -> VirtioIOMMUVfioResult {
40         let exists = |endpoint_addr: u32| -> bool {
41             for endpoints_range in self.hp_endpoints_ranges.iter() {
42                 if endpoints_range.contains(&endpoint_addr) {
43                     return true;
44                 }
45             }
46             false
47         };
48 
49         if !exists(endpoint_addr) {
50             return VirtioIOMMUVfioResult::NotInPCIRanges;
51         }
52 
53         self.endpoints
54             .insert(endpoint_addr, Arc::new(Mutex::new(Box::new(wrapper))));
55         VirtioIOMMUVfioResult::Ok
56     }
57 
handle_del_vfio_device( &mut self, pci_address: u32, ) -> VirtioIOMMUVfioResult58     pub(in crate::virtio::iommu) fn handle_del_vfio_device(
59         &mut self,
60         pci_address: u32,
61     ) -> VirtioIOMMUVfioResult {
62         if self.endpoints.remove(&pci_address).is_none() {
63             error!("There is no vfio container of {}", pci_address);
64             return VirtioIOMMUVfioResult::NoSuchDevice;
65         }
66         if let Some(domain) = self.endpoint_map.remove(&pci_address) {
67             self.domain_map.remove(&domain);
68         }
69         VirtioIOMMUVfioResult::Ok
70     }
71 
handle_map_dmabuf( &mut self, region_id: VmMemoryRegionId, gpa: u64, size: u64, dma_buf: File, ) -> VirtioIOMMUVfioResult72     pub(in crate::virtio::iommu) fn handle_map_dmabuf(
73         &mut self,
74         region_id: VmMemoryRegionId,
75         gpa: u64,
76         size: u64,
77         dma_buf: File,
78     ) -> VirtioIOMMUVfioResult {
79         if gpa & self.page_mask != 0 {
80             error!("cannot map dmabuf to non-page-aligned guest physical address");
81             return VirtioIOMMUVfioResult::InvalidParam;
82         }
83         let mmap = match MemoryMappingBuilder::new(size as usize)
84             .from_file(&dma_buf)
85             .build()
86         {
87             Ok(v) => v,
88             Err(_) => {
89                 error!("failed to mmap dma_buf");
90                 return VirtioIOMMUVfioResult::InvalidParam;
91             }
92         };
93         self.dmabuf_mem.insert(
94             gpa,
95             DmabufRegionEntry {
96                 mmap,
97                 region_id,
98                 size,
99             },
100         );
101 
102         VirtioIOMMUVfioResult::Ok
103     }
104 
handle_unmap_dmabuf( &mut self, region_id: VmMemoryRegionId, ) -> VirtioIOMMUVfioResult105     pub(in crate::virtio::iommu) fn handle_unmap_dmabuf(
106         &mut self,
107         region_id: VmMemoryRegionId,
108     ) -> VirtioIOMMUVfioResult {
109         if let Some(range) = self
110             .dmabuf_mem
111             .iter()
112             .find(|(_, dmabuf_entry)| dmabuf_entry.region_id == region_id)
113             .map(|entry| *entry.0)
114         {
115             self.dmabuf_mem.remove(&range);
116             VirtioIOMMUVfioResult::Ok
117         } else {
118             VirtioIOMMUVfioResult::NoSuchMappedDmabuf
119         }
120     }
121 
handle_vfio( &mut self, vfio_cmd: VirtioIOMMUVfioCommand, ) -> VirtioIOMMUResponse122     pub(in crate::virtio::iommu) fn handle_vfio(
123         &mut self,
124         vfio_cmd: VirtioIOMMUVfioCommand,
125     ) -> VirtioIOMMUResponse {
126         use VirtioIOMMUVfioCommand::*;
127         let vfio_result = match vfio_cmd {
128             VfioDeviceAdd {
129                 wrapper_id,
130                 container,
131                 endpoint_addr,
132             } => match VfioContainer::new_from_container(container) {
133                 Ok(vfio_container) => {
134                     let wrapper =
135                         VfioWrapper::new_with_id(vfio_container, wrapper_id, self.mem.clone());
136                     self.handle_add_vfio_device(endpoint_addr, wrapper)
137                 }
138                 Err(e) => {
139                     error!("failed to verify the new container: {}", e);
140                     VirtioIOMMUVfioResult::NoAvailableContainer
141                 }
142             },
143             VfioDeviceDel { endpoint_addr } => self.handle_del_vfio_device(endpoint_addr),
144             VfioDmabufMap {
145                 region_id,
146                 gpa,
147                 size,
148                 dma_buf,
149             } => self.handle_map_dmabuf(region_id, gpa, size, File::from(dma_buf)),
150             VfioDmabufUnmap(region_id) => self.handle_unmap_dmabuf(region_id),
151         };
152         VirtioIOMMUResponse::VfioResponse(vfio_result)
153     }
154 }
155 
handle_command_tube( state: &Rc<RefCell<State>>, command_tube: AsyncTube, ) -> Result<()>156 pub(in crate::virtio::iommu) async fn handle_command_tube(
157     state: &Rc<RefCell<State>>,
158     command_tube: AsyncTube,
159 ) -> Result<()> {
160     loop {
161         match command_tube.next::<VirtioIOMMURequest>().await {
162             Ok(command) => {
163                 let response: VirtioIOMMUResponse = match command {
164                     VirtioIOMMURequest::VfioCommand(vfio_cmd) => {
165                         state.borrow_mut().handle_vfio(vfio_cmd)
166                     }
167                 };
168                 if let Err(e) = command_tube.send(response).await {
169                     error!("{}", IommuError::VirtioIOMMUResponseError(e));
170                 }
171             }
172             Err(e) => {
173                 return Err(IommuError::VirtioIOMMUReqError(e));
174             }
175         }
176     }
177 }
178 
handle_translate_request( ex: &Executor, state: &Rc<RefCell<State>>, request_tube: Option<AsyncTube>, response_tubes: Option<BTreeMap<u32, AsyncTube>>, ) -> Result<()>179 pub(in crate::virtio::iommu) async fn handle_translate_request(
180     ex: &Executor,
181     state: &Rc<RefCell<State>>,
182     request_tube: Option<AsyncTube>,
183     response_tubes: Option<BTreeMap<u32, AsyncTube>>,
184 ) -> Result<()> {
185     let request_tube = match request_tube {
186         Some(r) => r,
187         None => {
188             futures::future::pending::<()>().await;
189             return Ok(());
190         }
191     };
192     let response_tubes = response_tubes.unwrap();
193     loop {
194         let req: IommuRequest = match request_tube.next().await {
195             Ok(req) => req,
196             Err(TubeError::Disconnected) => {
197                 // This means the process on the other side of the tube went away. That's
198                 // not a problem with virtio-iommu itself, so just exit this callback
199                 // and wait for crosvm to exit.
200                 return Ok(());
201             }
202             Err(e) => {
203                 return Err(IommuError::Tube(e));
204             }
205         };
206         let resp = if let Some(mapper) = state.borrow().endpoints.get(&req.get_endpoint_id()) {
207             match req {
208                 IommuRequest::Export { iova, size, .. } => {
209                     mapper.lock().export(iova, size).map(IommuResponse::Export)
210                 }
211                 IommuRequest::Release { iova, size, .. } => mapper
212                     .lock()
213                     .release(iova, size)
214                     .map(|_| IommuResponse::Release),
215                 IommuRequest::StartExportSession { .. } => mapper
216                     .lock()
217                     .start_export_session(ex)
218                     .map(IommuResponse::StartExportSession),
219             }
220         } else {
221             error!("endpoint {} not found", req.get_endpoint_id());
222             continue;
223         };
224         let resp: IommuResponse = match resp {
225             Ok(resp) => resp,
226             Err(e) => IommuResponse::Err(format!("{:?}", e)),
227         };
228         response_tubes
229             .get(&req.get_endpoint_id())
230             .unwrap()
231             .send(resp)
232             .await
233             .map_err(IommuError::Tube)?;
234     }
235 }
236