xref: /aosp_15_r20/external/crosvm/devices/src/usb/xhci/device_slot.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2019 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::mem::size_of;
6 use std::sync::atomic::AtomicBool;
7 use std::sync::atomic::Ordering;
8 use std::sync::Arc;
9 
10 use base::debug;
11 use base::error;
12 use base::info;
13 use bit_field::Error as BitFieldError;
14 use remain::sorted;
15 use sync::Mutex;
16 use thiserror::Error;
17 use vm_memory::GuestAddress;
18 use vm_memory::GuestMemory;
19 use vm_memory::GuestMemoryError;
20 
21 use super::interrupter::Interrupter;
22 use super::transfer_ring_controller::TransferRingController;
23 use super::transfer_ring_controller::TransferRingControllerError;
24 use super::transfer_ring_controller::TransferRingControllers;
25 use super::usb_hub;
26 use super::usb_hub::UsbHub;
27 use super::xhci_abi::AddressDeviceCommandTrb;
28 use super::xhci_abi::ConfigureEndpointCommandTrb;
29 use super::xhci_abi::DequeuePtr;
30 use super::xhci_abi::DeviceContext;
31 use super::xhci_abi::DeviceSlotState;
32 use super::xhci_abi::EndpointContext;
33 use super::xhci_abi::EndpointState;
34 use super::xhci_abi::EvaluateContextCommandTrb;
35 use super::xhci_abi::InputControlContext;
36 use super::xhci_abi::SlotContext;
37 use super::xhci_abi::StreamContextArray;
38 use super::xhci_abi::TrbCompletionCode;
39 use super::xhci_abi::DEVICE_CONTEXT_ENTRY_SIZE;
40 use super::xhci_backend_device::XhciBackendDevice;
41 use super::xhci_regs::valid_max_pstreams;
42 use super::xhci_regs::valid_slot_id;
43 use super::xhci_regs::MAX_PORTS;
44 use super::xhci_regs::MAX_SLOTS;
45 use crate::register_space::Register;
46 use crate::usb::backend::error::Error as BackendProviderError;
47 use crate::usb::xhci::ring_buffer_stop_cb::fallible_closure;
48 use crate::usb::xhci::ring_buffer_stop_cb::RingBufferStopCallback;
49 use crate::utils::EventLoop;
50 use crate::utils::FailHandle;
51 
52 #[sorted]
53 #[derive(Error, Debug)]
54 pub enum Error {
55     #[error("failed to allocate streams: {0}")]
56     AllocStreams(BackendProviderError),
57     #[error("bad device context: {0}")]
58     BadDeviceContextAddr(GuestAddress),
59     #[error("bad endpoint context: {0}")]
60     BadEndpointContext(GuestAddress),
61     #[error("device slot get a bad endpoint id: {0}")]
62     BadEndpointId(u8),
63     #[error("bad input context address: {0}")]
64     BadInputContextAddr(GuestAddress),
65     #[error("device slot get a bad port id: {0}")]
66     BadPortId(u8),
67     #[error("bad stream context type: {0}")]
68     BadStreamContextType(u8),
69     #[error("callback failed")]
70     CallbackFailed,
71     #[error("failed to create transfer controller: {0}")]
72     CreateTransferController(TransferRingControllerError),
73     #[error("failed to free streams: {0}")]
74     FreeStreams(BackendProviderError),
75     #[error("failed to get endpoint state: {0}")]
76     GetEndpointState(BitFieldError),
77     #[error("failed to get port: {0}")]
78     GetPort(u8),
79     #[error("failed to get slot context state: {0}")]
80     GetSlotContextState(BitFieldError),
81     #[error("failed to get trc: {0}")]
82     GetTrc(u8),
83     #[error("failed to read guest memory: {0}")]
84     ReadGuestMemory(GuestMemoryError),
85     #[error("failed to reset port: {0}")]
86     ResetPort(BackendProviderError),
87     #[error("failed to upgrade weak reference")]
88     WeakReferenceUpgrade,
89     #[error("failed to write guest memory: {0}")]
90     WriteGuestMemory(GuestMemoryError),
91 }
92 
93 type Result<T> = std::result::Result<T, Error>;
94 
/// Number of per-device transfer ring controller entries, one per endpoint index.
///
/// An endpoint index is the Device Context Index (DCI, xHCI spec 4.5.1) minus one:
/// index 0: Control endpoint. Device Context Index: 1.
/// index 1: Endpoint 1 out. Device Context Index: 2.
/// index 2: Endpoint 1 in. Device Context Index: 3.
/// index 3: Endpoint 2 out. Device Context Index: 4.
/// ...
/// index 30: Endpoint 15 in. Device Context Index: 31.
pub const TRANSFER_RING_CONTROLLERS_INDEX_END: usize = 31;
/// Exclusive upper bound of device context indices; valid DCIs are `1..DCI_INDEX_END`.
pub const DCI_INDEX_END: u8 = TRANSFER_RING_CONTROLLERS_INDEX_END as u8 + 1;
/// Device context index of the first non-control endpoint (endpoint 1 out).
pub const FIRST_TRANSFER_ENDPOINT_DCI: u8 = 2;

/// Returns whether `endpoint_id` is a usable device context index (1 through 31).
fn valid_endpoint_id(endpoint_id: u8) -> bool {
    (1..DCI_INDEX_END).contains(&endpoint_id)
}
111 
112 #[derive(Clone)]
113 pub struct DeviceSlots {
114     fail_handle: Arc<dyn FailHandle>,
115     hub: Arc<UsbHub>,
116     slots: Vec<Arc<DeviceSlot>>,
117 }
118 
119 impl DeviceSlots {
new( fail_handle: Arc<dyn FailHandle>, dcbaap: Register<u64>, hub: Arc<UsbHub>, interrupter: Arc<Mutex<Interrupter>>, event_loop: Arc<EventLoop>, mem: GuestMemory, ) -> DeviceSlots120     pub fn new(
121         fail_handle: Arc<dyn FailHandle>,
122         dcbaap: Register<u64>,
123         hub: Arc<UsbHub>,
124         interrupter: Arc<Mutex<Interrupter>>,
125         event_loop: Arc<EventLoop>,
126         mem: GuestMemory,
127     ) -> DeviceSlots {
128         let mut slots = Vec::new();
129         for slot_id in 1..=MAX_SLOTS {
130             slots.push(Arc::new(DeviceSlot::new(
131                 slot_id,
132                 dcbaap.clone(),
133                 hub.clone(),
134                 interrupter.clone(),
135                 event_loop.clone(),
136                 mem.clone(),
137             )));
138         }
139         DeviceSlots {
140             fail_handle,
141             hub,
142             slots,
143         }
144     }
145 
146     /// Note that slot id starts from 1. Slot index start from 0.
slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>>147     pub fn slot(&self, slot_id: u8) -> Option<Arc<DeviceSlot>> {
148         if valid_slot_id(slot_id) {
149             Some(self.slots[slot_id as usize - 1].clone())
150         } else {
151             error!(
152                 "trying to index a wrong slot id {}, max slot = {}",
153                 slot_id, MAX_SLOTS
154             );
155             None
156         }
157     }
158 
159     /// Reset the device connected to a specific port.
reset_port(&self, port_id: u8) -> Result<()>160     pub fn reset_port(&self, port_id: u8) -> Result<()> {
161         if let Some(port) = self.hub.get_port(port_id) {
162             if let Some(backend_device) = port.backend_device().as_mut() {
163                 backend_device.lock().reset().map_err(Error::ResetPort)?;
164             }
165         }
166 
167         // No device on port, so nothing to reset.
168         Ok(())
169     }
170 
171     /// Stop all device slots and reset them.
stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C)172     pub fn stop_all_and_reset<C: FnMut() + 'static + Send>(&self, mut callback: C) {
173         info!("xhci: stopping all device slots and resetting host hub");
174         let slots = self.slots.clone();
175         let hub = self.hub.clone();
176         let auto_callback = RingBufferStopCallback::new(fallible_closure(
177             self.fail_handle.clone(),
178             move || -> std::result::Result<(), usb_hub::Error> {
179                 for slot in &slots {
180                     slot.reset();
181                 }
182                 hub.reset()?;
183                 callback();
184                 Ok(())
185             },
186         ));
187         self.stop_all(auto_callback);
188     }
189 
190     /// Stop all devices. The auto callback will be executed when all trc is stopped. It could
191     /// happen asynchronously, if there are any pending transfers.
stop_all(&self, auto_callback: RingBufferStopCallback)192     pub fn stop_all(&self, auto_callback: RingBufferStopCallback) {
193         for slot in &self.slots {
194             slot.stop_all_trc(auto_callback.clone());
195         }
196     }
197 
198     /// Disable a slot. This might happen asynchronously, if there is any pending transfers. The
199     /// callback will be invoked when slot is actually disabled.
disable_slot< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( &self, slot_id: u8, cb: C, ) -> Result<()>200     pub fn disable_slot<
201         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
202     >(
203         &self,
204         slot_id: u8,
205         cb: C,
206     ) -> Result<()> {
207         xhci_trace!("device slot {} is being disabled", slot_id);
208         DeviceSlot::disable(
209             self.fail_handle.clone(),
210             &self.slots[slot_id as usize - 1],
211             cb,
212         )
213     }
214 
215     /// Reset a slot. This is a shortcut call for DeviceSlot::reset_slot.
reset_slot< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( &self, slot_id: u8, cb: C, ) -> Result<()>216     pub fn reset_slot<
217         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
218     >(
219         &self,
220         slot_id: u8,
221         cb: C,
222     ) -> Result<()> {
223         xhci_trace!("device slot {} is resetting", slot_id);
224         DeviceSlot::reset_slot(
225             self.fail_handle.clone(),
226             &self.slots[slot_id as usize - 1],
227             cb,
228         )
229     }
230 
stop_endpoint< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( &self, slot_id: u8, endpoint_id: u8, cb: C, ) -> Result<()>231     pub fn stop_endpoint<
232         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
233     >(
234         &self,
235         slot_id: u8,
236         endpoint_id: u8,
237         cb: C,
238     ) -> Result<()> {
239         self.slots[slot_id as usize - 1].stop_endpoint(self.fail_handle.clone(), endpoint_id, cb)
240     }
241 
reset_endpoint< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( &self, slot_id: u8, endpoint_id: u8, cb: C, ) -> Result<()>242     pub fn reset_endpoint<
243         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
244     >(
245         &self,
246         slot_id: u8,
247         endpoint_id: u8,
248         cb: C,
249     ) -> Result<()> {
250         self.slots[slot_id as usize - 1].reset_endpoint(self.fail_handle.clone(), endpoint_id, cb)
251     }
252 }
253 
254 // Usb port id. Valid ids starts from 1, to MAX_PORTS.
255 struct PortId(Mutex<u8>);
256 
257 impl PortId {
new() -> Self258     fn new() -> Self {
259         PortId(Mutex::new(0))
260     }
261 
set(&self, value: u8) -> Result<()>262     fn set(&self, value: u8) -> Result<()> {
263         if !(1..=MAX_PORTS).contains(&value) {
264             return Err(Error::BadPortId(value));
265         }
266         *self.0.lock() = value;
267         Ok(())
268     }
269 
reset(&self)270     fn reset(&self) {
271         *self.0.lock() = 0;
272     }
273 
get(&self) -> Result<u8>274     fn get(&self) -> Result<u8> {
275         let val = *self.0.lock();
276         if val == 0 {
277             return Err(Error::BadPortId(val));
278         }
279         Ok(val)
280     }
281 }
282 
283 pub struct DeviceSlot {
284     slot_id: u8,
285     port_id: PortId, // Valid port id starts from 1, to MAX_PORTS.
286     dcbaap: Register<u64>,
287     hub: Arc<UsbHub>,
288     interrupter: Arc<Mutex<Interrupter>>,
289     event_loop: Arc<EventLoop>,
290     mem: GuestMemory,
291     enabled: AtomicBool,
292     transfer_ring_controllers: Mutex<Vec<Option<TransferRingControllers>>>,
293 }
294 
295 impl DeviceSlot {
296     /// Create a new device slot.
new( slot_id: u8, dcbaap: Register<u64>, hub: Arc<UsbHub>, interrupter: Arc<Mutex<Interrupter>>, event_loop: Arc<EventLoop>, mem: GuestMemory, ) -> Self297     pub fn new(
298         slot_id: u8,
299         dcbaap: Register<u64>,
300         hub: Arc<UsbHub>,
301         interrupter: Arc<Mutex<Interrupter>>,
302         event_loop: Arc<EventLoop>,
303         mem: GuestMemory,
304     ) -> Self {
305         let mut transfer_ring_controllers = Vec::new();
306         transfer_ring_controllers.resize_with(TRANSFER_RING_CONTROLLERS_INDEX_END, || None);
307         DeviceSlot {
308             slot_id,
309             port_id: PortId::new(),
310             dcbaap,
311             hub,
312             interrupter,
313             event_loop,
314             mem,
315             enabled: AtomicBool::new(false),
316             transfer_ring_controllers: Mutex::new(transfer_ring_controllers),
317         }
318     }
319 
get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>>320     fn get_trc(&self, i: usize, stream_id: u16) -> Option<Arc<TransferRingController>> {
321         let trcs = self.transfer_ring_controllers.lock();
322         match &trcs[i] {
323             Some(TransferRingControllers::Endpoint(trc)) => Some(trc.clone()),
324             Some(TransferRingControllers::Stream(trcs)) => {
325                 let stream_id = stream_id as usize;
326                 if stream_id > 0 && stream_id <= trcs.len() {
327                     Some(trcs[stream_id - 1].clone())
328                 } else {
329                     None
330                 }
331             }
332             None => None,
333         }
334     }
335 
get_trcs(&self, i: usize) -> Option<TransferRingControllers>336     fn get_trcs(&self, i: usize) -> Option<TransferRingControllers> {
337         let trcs = self.transfer_ring_controllers.lock();
338         trcs[i].clone()
339     }
340 
set_trcs(&self, i: usize, trc: Option<TransferRingControllers>)341     fn set_trcs(&self, i: usize, trc: Option<TransferRingControllers>) {
342         let mut trcs = self.transfer_ring_controllers.lock();
343         trcs[i] = trc;
344     }
345 
trc_len(&self) -> usize346     fn trc_len(&self) -> usize {
347         self.transfer_ring_controllers.lock().len()
348     }
349 
350     /// The arguments are identical to the fields in each doorbell register. The
351     /// target value:
352     /// 1: Reserved
353     /// 2: Control endpoint
354     /// 3: Endpoint 1 out
355     /// 4: Endpoint 1 in
356     /// 5: Endpoint 2 out
357     /// ...
358     /// 32: Endpoint 15 in
359     ///
360     /// Steam ID will be useful when host controller support streams.
361     /// The stream ID must be zero for endpoints that do not have streams
362     /// configured.
363     /// This function will return false if it fails to trigger transfer ring start.
ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool>364     pub fn ring_doorbell(&self, target: u8, stream_id: u16) -> Result<bool> {
365         if !valid_endpoint_id(target) {
366             error!(
367                 "device slot {}: Invalid target written to doorbell register. target: {}",
368                 self.slot_id, target
369             );
370             return Ok(false);
371         }
372         xhci_trace!(
373             "device slot {}: ring_doorbell target = {} stream_id = {}",
374             self.slot_id,
375             target,
376             stream_id
377         );
378         // See DCI in spec.
379         let endpoint_index = (target - 1) as usize;
380         let transfer_ring_controller = match self.get_trc(endpoint_index, stream_id) {
381             Some(tr) => tr,
382             None => {
383                 error!("Device endpoint is not inited");
384                 return Ok(false);
385             }
386         };
387         let mut context = self.get_device_context()?;
388         let endpoint_state = context.endpoint_context[endpoint_index]
389             .get_endpoint_state()
390             .map_err(Error::GetEndpointState)?;
391         if endpoint_state == EndpointState::Running || endpoint_state == EndpointState::Stopped {
392             if endpoint_state == EndpointState::Stopped {
393                 context.endpoint_context[endpoint_index].set_endpoint_state(EndpointState::Running);
394                 self.set_device_context(context)?;
395             }
396             // endpoint is started, start transfer ring
397             transfer_ring_controller.start();
398         } else {
399             error!("doorbell rung when endpoint state is {:?}", endpoint_state);
400         }
401         Ok(true)
402     }
403 
404     /// Enable the slot. This function returns false if it's already enabled.
enable(&self) -> bool405     pub fn enable(&self) -> bool {
406         let was_already_enabled = self.enabled.swap(true, Ordering::SeqCst);
407         if was_already_enabled {
408             error!("device slot is already enabled");
409         }
410         !was_already_enabled
411     }
412 
413     /// Disable this device slot. If the slot is not enabled, callback will be invoked immediately
414     /// with error. Otherwise, callback will be invoked when all trc is stopped.
disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>( fail_handle: Arc<dyn FailHandle>, slot: &Arc<DeviceSlot>, mut callback: C, ) -> Result<()>415     pub fn disable<C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send>(
416         fail_handle: Arc<dyn FailHandle>,
417         slot: &Arc<DeviceSlot>,
418         mut callback: C,
419     ) -> Result<()> {
420         if slot.enabled.load(Ordering::SeqCst) {
421             let slot_weak = Arc::downgrade(slot);
422             let auto_callback =
423                 RingBufferStopCallback::new(fallible_closure(fail_handle, move || {
424                     // Slot should still be alive when the callback is invoked. If it's not, there
425                     // must be a bug somewhere.
426                     let slot = slot_weak.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
427                     let mut device_context = slot.get_device_context()?;
428                     device_context
429                         .slot_context
430                         .set_slot_state(DeviceSlotState::DisabledOrEnabled);
431                     slot.set_device_context(device_context)?;
432                     slot.reset();
433                     debug!(
434                         "device slot {}: all trc disabled, sending trb",
435                         slot.slot_id
436                     );
437                     callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
438                 }));
439             slot.stop_all_trc(auto_callback);
440             Ok(())
441         } else {
442             callback(TrbCompletionCode::SlotNotEnabledError).map_err(|_| Error::CallbackFailed)
443         }
444     }
445 
446     // Assigns the device address and initializes slot and endpoint 0 context.
set_address( self: &Arc<Self>, trb: &AddressDeviceCommandTrb, ) -> Result<TrbCompletionCode>447     pub fn set_address(
448         self: &Arc<Self>,
449         trb: &AddressDeviceCommandTrb,
450     ) -> Result<TrbCompletionCode> {
451         if !self.enabled.load(Ordering::SeqCst) {
452             error!(
453                 "trying to set address to a disabled device slot {}",
454                 self.slot_id
455             );
456             return Ok(TrbCompletionCode::SlotNotEnabledError);
457         }
458         let device_context = self.get_device_context()?;
459         let state = device_context
460             .slot_context
461             .get_slot_state()
462             .map_err(Error::GetSlotContextState)?;
463         match state {
464             DeviceSlotState::DisabledOrEnabled => {}
465             DeviceSlotState::Default if !trb.get_block_set_address_request() => {}
466             _ => {
467                 error!("slot {} has unexpected slot state", self.slot_id);
468                 return Ok(TrbCompletionCode::ContextStateError);
469             }
470         }
471 
472         // Copy all fields of the slot context and endpoint 0 context from the input context
473         // to the output context.
474         let input_context_ptr = GuestAddress(trb.get_input_context_pointer());
475         // Copy slot context.
476         self.copy_context(input_context_ptr, 0)?;
477         // Copy control endpoint context.
478         self.copy_context(input_context_ptr, 1)?;
479 
480         // Read back device context.
481         let mut device_context = self.get_device_context()?;
482         let port_id = device_context.slot_context.get_root_hub_port_number();
483         self.port_id.set(port_id)?;
484         debug!(
485             "port id {} is assigned to slot id {}",
486             port_id, self.slot_id
487         );
488 
489         // Initialize the control endpoint. Endpoint id = 1.
490         let trc = TransferRingController::new(
491             self.mem.clone(),
492             self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?,
493             self.event_loop.clone(),
494             self.interrupter.clone(),
495             self.slot_id,
496             1,
497             Arc::downgrade(self),
498             None,
499         )
500         .map_err(Error::CreateTransferController)?;
501         self.set_trcs(0, Some(TransferRingControllers::Endpoint(trc)));
502 
503         // Assign slot ID as device address if block_set_address_request is not set.
504         if trb.get_block_set_address_request() {
505             device_context
506                 .slot_context
507                 .set_slot_state(DeviceSlotState::Default);
508         } else {
509             let port = self.hub.get_port(port_id).ok_or(Error::GetPort(port_id))?;
510             match port.backend_device().as_mut() {
511                 Some(backend) => {
512                     backend.lock().set_address(self.slot_id as u32);
513                 }
514                 None => {
515                     return Ok(TrbCompletionCode::TransactionError);
516                 }
517             }
518 
519             device_context
520                 .slot_context
521                 .set_usb_device_address(self.slot_id);
522             device_context
523                 .slot_context
524                 .set_slot_state(DeviceSlotState::Addressed);
525         }
526 
527         // TODO(jkwang) trc should always exists. Fix this.
528         self.get_trc(0, 0)
529             .ok_or(Error::GetTrc(0))?
530             .set_dequeue_pointer(
531                 device_context.endpoint_context[0]
532                     .get_tr_dequeue_pointer()
533                     .get_gpa(),
534             );
535 
536         self.get_trc(0, 0)
537             .ok_or(Error::GetTrc(0))?
538             .set_consumer_cycle_state(device_context.endpoint_context[0].get_dequeue_cycle_state());
539 
540         // Setting endpoint 0 to running
541         device_context.endpoint_context[0].set_endpoint_state(EndpointState::Running);
542         self.set_device_context(device_context)?;
543         Ok(TrbCompletionCode::Success)
544     }
545 
546     // Adds or drops multiple endpoints in the device slot.
configure_endpoint( self: &Arc<Self>, trb: &ConfigureEndpointCommandTrb, ) -> Result<TrbCompletionCode>547     pub fn configure_endpoint(
548         self: &Arc<Self>,
549         trb: &ConfigureEndpointCommandTrb,
550     ) -> Result<TrbCompletionCode> {
551         let input_control_context = if trb.get_deconfigure() {
552             // From section 4.6.6 of the xHCI spec:
553             // Setting the deconfigure (DC) flag to '1' in the Configure Endpoint Command
554             // TRB is equivalent to setting Input Context Drop Context flags 2-31 to '1'
555             // and Add Context 2-31 flags to '0'.
556             let mut c = InputControlContext::new();
557             c.set_add_context_flags(0);
558             c.set_drop_context_flags(0xfffffffc);
559             c
560         } else {
561             self.mem
562                 .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
563                 .map_err(Error::ReadGuestMemory)?
564         };
565 
566         for device_context_index in 1..DCI_INDEX_END {
567             if input_control_context.drop_context_flag(device_context_index) {
568                 self.drop_one_endpoint(device_context_index)?;
569             }
570             if input_control_context.add_context_flag(device_context_index) {
571                 self.copy_context(
572                     GuestAddress(trb.get_input_context_pointer()),
573                     device_context_index,
574                 )?;
575                 self.add_one_endpoint(device_context_index)?;
576             }
577         }
578 
579         if trb.get_deconfigure() {
580             self.set_state(DeviceSlotState::Addressed)?;
581         } else {
582             self.set_state(DeviceSlotState::Configured)?;
583         }
584         Ok(TrbCompletionCode::Success)
585     }
586 
587     // Evaluates the device context by reading new values for certain fields of
588     // the slot context and/or control endpoint context.
evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode>589     pub fn evaluate_context(&self, trb: &EvaluateContextCommandTrb) -> Result<TrbCompletionCode> {
590         if !self.enabled.load(Ordering::SeqCst) {
591             return Ok(TrbCompletionCode::SlotNotEnabledError);
592         }
593         // TODO(jkwang) verify this
594         // The spec has multiple contradictions about validating context parameters in sections
595         // 4.6.7, 6.2.3.3. To keep things as simple as possible we do no further validation here.
596         let input_control_context: InputControlContext = self
597             .mem
598             .read_obj_from_addr(GuestAddress(trb.get_input_context_pointer()))
599             .map_err(Error::ReadGuestMemory)?;
600 
601         let mut device_context = self.get_device_context()?;
602         if input_control_context.add_context_flag(0) {
603             let input_slot_context: SlotContext = self
604                 .mem
605                 .read_obj_from_addr(GuestAddress(
606                     trb.get_input_context_pointer() + DEVICE_CONTEXT_ENTRY_SIZE as u64,
607                 ))
608                 .map_err(Error::ReadGuestMemory)?;
609             device_context
610                 .slot_context
611                 .set_interrupter_target(input_slot_context.get_interrupter_target());
612 
613             device_context
614                 .slot_context
615                 .set_max_exit_latency(input_slot_context.get_max_exit_latency());
616         }
617 
618         // From 6.2.3.3: "Endpoint Contexts 2 throught 31 shall not be evaluated by the Evaluate
619         // Context Command".
620         if input_control_context.add_context_flag(1) {
621             let ep0_context: EndpointContext = self
622                 .mem
623                 .read_obj_from_addr(GuestAddress(
624                     trb.get_input_context_pointer() + 2 * DEVICE_CONTEXT_ENTRY_SIZE as u64,
625                 ))
626                 .map_err(Error::ReadGuestMemory)?;
627             device_context.endpoint_context[0]
628                 .set_max_packet_size(ep0_context.get_max_packet_size());
629         }
630         self.set_device_context(device_context)?;
631         Ok(TrbCompletionCode::Success)
632     }
633 
634     /// Reset the device slot to default state and deconfigures all but the
635     /// control endpoint.
reset_slot< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( fail_handle: Arc<dyn FailHandle>, slot: &Arc<DeviceSlot>, mut callback: C, ) -> Result<()>636     pub fn reset_slot<
637         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
638     >(
639         fail_handle: Arc<dyn FailHandle>,
640         slot: &Arc<DeviceSlot>,
641         mut callback: C,
642     ) -> Result<()> {
643         let weak_s = Arc::downgrade(slot);
644         let auto_callback =
645             RingBufferStopCallback::new(fallible_closure(fail_handle, move || -> Result<()> {
646                 let s = weak_s.upgrade().ok_or(Error::WeakReferenceUpgrade)?;
647                 for i in FIRST_TRANSFER_ENDPOINT_DCI..DCI_INDEX_END {
648                     s.drop_one_endpoint(i)?;
649                 }
650                 let mut ctx = s.get_device_context()?;
651                 ctx.slot_context.set_slot_state(DeviceSlotState::Default);
652                 ctx.slot_context.set_context_entries(1);
653                 ctx.slot_context.set_root_hub_port_number(0);
654                 s.set_device_context(ctx)?;
655                 callback(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)?;
656                 Ok(())
657             }));
658         slot.stop_all_trc(auto_callback);
659         Ok(())
660     }
661 
662     /// Stop all transfer ring controllers.
stop_all_trc(&self, auto_callback: RingBufferStopCallback)663     pub fn stop_all_trc(&self, auto_callback: RingBufferStopCallback) {
664         for i in 0..self.trc_len() {
665             if let Some(trcs) = self.get_trcs(i) {
666                 match trcs {
667                     TransferRingControllers::Endpoint(trc) => {
668                         trc.stop(auto_callback.clone());
669                     }
670                     TransferRingControllers::Stream(trcs) => {
671                         for trc in trcs {
672                             trc.stop(auto_callback.clone());
673                         }
674                     }
675                 }
676             }
677         }
678     }
679 
680     /// Stop an endpoint.
stop_endpoint< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( &self, fail_handle: Arc<dyn FailHandle>, endpoint_id: u8, mut cb: C, ) -> Result<()>681     pub fn stop_endpoint<
682         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
683     >(
684         &self,
685         fail_handle: Arc<dyn FailHandle>,
686         endpoint_id: u8,
687         mut cb: C,
688     ) -> Result<()> {
689         if !valid_endpoint_id(endpoint_id) {
690             error!("trb indexing wrong endpoint id");
691             return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
692         }
693         let index = endpoint_id - 1;
694         let mut device_context = self.get_device_context()?;
695         let endpoint_context = &mut device_context.endpoint_context[index as usize];
696         match self.get_trcs(index as usize) {
697             Some(TransferRingControllers::Endpoint(trc)) => {
698                 let auto_cb = RingBufferStopCallback::new(fallible_closure(
699                     fail_handle,
700                     move || -> Result<()> {
701                         cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
702                     },
703                 ));
704                 trc.stop(auto_cb);
705                 let dequeue_pointer = trc.get_dequeue_pointer();
706                 let dcs = trc.get_consumer_cycle_state();
707                 endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
708                 endpoint_context.set_dequeue_cycle_state(dcs);
709             }
710             Some(TransferRingControllers::Stream(trcs)) => {
711                 let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
712                 let mut stream_context_array: StreamContextArray = self
713                     .mem
714                     .read_obj_from_addr(stream_context_array_addr)
715                     .map_err(Error::ReadGuestMemory)?;
716                 let auto_cb = RingBufferStopCallback::new(fallible_closure(
717                     fail_handle,
718                     move || -> Result<()> {
719                         cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
720                     },
721                 ));
722                 for (i, trc) in trcs.iter().enumerate() {
723                     let dequeue_pointer = trc.get_dequeue_pointer();
724                     let dcs = trc.get_consumer_cycle_state();
725                     trc.stop(auto_cb.clone());
726                     stream_context_array.stream_contexts[i + 1]
727                         .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
728                     stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
729                 }
730                 self.mem
731                     .write_obj_at_addr(stream_context_array, stream_context_array_addr)
732                     .map_err(Error::WriteGuestMemory)?;
733             }
734             None => {
735                 error!("endpoint at index {} is not started", index);
736                 cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
737             }
738         }
739         endpoint_context.set_endpoint_state(EndpointState::Stopped);
740         self.set_device_context(device_context)?;
741         Ok(())
742     }
743 
744     /// Reset an endpoint.
reset_endpoint< C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send, >( &self, fail_handle: Arc<dyn FailHandle>, endpoint_id: u8, mut cb: C, ) -> Result<()>745     pub fn reset_endpoint<
746         C: FnMut(TrbCompletionCode) -> std::result::Result<(), ()> + 'static + Send,
747     >(
748         &self,
749         fail_handle: Arc<dyn FailHandle>,
750         endpoint_id: u8,
751         mut cb: C,
752     ) -> Result<()> {
753         if !valid_endpoint_id(endpoint_id) {
754             error!("trb indexing wrong endpoint id");
755             return cb(TrbCompletionCode::TrbError).map_err(|_| Error::CallbackFailed);
756         }
757         let index = endpoint_id - 1;
758         let mut device_context = self.get_device_context()?;
759         let endpoint_context = &mut device_context.endpoint_context[index as usize];
760         if endpoint_context
761             .get_endpoint_state()
762             .map_err(Error::GetEndpointState)?
763             != EndpointState::Halted
764         {
765             error!("endpoint at index {} is not halted", index);
766             return cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed);
767         }
768         match self.get_trcs(index as usize) {
769             Some(TransferRingControllers::Endpoint(trc)) => {
770                 let auto_cb = RingBufferStopCallback::new(fallible_closure(
771                     fail_handle,
772                     move || -> Result<()> {
773                         cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
774                     },
775                 ));
776                 trc.stop(auto_cb);
777                 let dequeue_pointer = trc.get_dequeue_pointer();
778                 let dcs = trc.get_consumer_cycle_state();
779                 endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
780                 endpoint_context.set_dequeue_cycle_state(dcs);
781             }
782             Some(TransferRingControllers::Stream(trcs)) => {
783                 let stream_context_array_addr = endpoint_context.get_tr_dequeue_pointer().get_gpa();
784                 let mut stream_context_array: StreamContextArray = self
785                     .mem
786                     .read_obj_from_addr(stream_context_array_addr)
787                     .map_err(Error::ReadGuestMemory)?;
788                 let auto_cb = RingBufferStopCallback::new(fallible_closure(
789                     fail_handle,
790                     move || -> Result<()> {
791                         cb(TrbCompletionCode::Success).map_err(|_| Error::CallbackFailed)
792                     },
793                 ));
794                 for (i, trc) in trcs.iter().enumerate() {
795                     let dequeue_pointer = trc.get_dequeue_pointer();
796                     let dcs = trc.get_consumer_cycle_state();
797                     trc.stop(auto_cb.clone());
798                     stream_context_array.stream_contexts[i + 1]
799                         .set_tr_dequeue_pointer(DequeuePtr::new(dequeue_pointer));
800                     stream_context_array.stream_contexts[i + 1].set_dequeue_cycle_state(dcs);
801                 }
802                 self.mem
803                     .write_obj_at_addr(stream_context_array, stream_context_array_addr)
804                     .map_err(Error::WriteGuestMemory)?;
805             }
806             None => {
807                 error!("endpoint at index {} is not started", index);
808                 cb(TrbCompletionCode::ContextStateError).map_err(|_| Error::CallbackFailed)?;
809             }
810         }
811         endpoint_context.set_endpoint_state(EndpointState::Stopped);
812         self.set_device_context(device_context)?;
813         Ok(())
814     }
815 
816     /// Set transfer ring dequeue pointer.
set_tr_dequeue_ptr( &self, endpoint_id: u8, stream_id: u16, ptr: u64, ) -> Result<TrbCompletionCode>817     pub fn set_tr_dequeue_ptr(
818         &self,
819         endpoint_id: u8,
820         stream_id: u16,
821         ptr: u64,
822     ) -> Result<TrbCompletionCode> {
823         if !valid_endpoint_id(endpoint_id) {
824             error!("trb indexing wrong endpoint id");
825             return Ok(TrbCompletionCode::TrbError);
826         }
827         let index = (endpoint_id - 1) as usize;
828         match self.get_trc(index, stream_id) {
829             Some(trc) => {
830                 trc.set_dequeue_pointer(GuestAddress(ptr));
831                 let mut ctx = self.get_device_context()?;
832                 ctx.endpoint_context[index]
833                     .set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(ptr)));
834                 self.set_device_context(ctx)?;
835                 Ok(TrbCompletionCode::Success)
836             }
837             None => {
838                 error!("set tr dequeue ptr failed due to no trc started");
839                 Ok(TrbCompletionCode::ContextStateError)
840             }
841         }
842     }
843 
844     // Reset and reset_slot are different.
845     // Reset_slot handles command ring `reset slot` command. It will reset the slot state.
846     // Reset handles xhci reset. It will destroy everything.
reset(&self)847     fn reset(&self) {
848         for i in 0..self.trc_len() {
849             self.set_trcs(i, None);
850         }
851         debug!("resetting device slot {}!", self.slot_id);
852         self.enabled.store(false, Ordering::SeqCst);
853         self.port_id.reset();
854     }
855 
create_stream_trcs( self: &Arc<Self>, stream_context_array_addr: GuestAddress, max_pstreams: u8, device_context_index: u8, ) -> Result<TransferRingControllers>856     fn create_stream_trcs(
857         self: &Arc<Self>,
858         stream_context_array_addr: GuestAddress,
859         max_pstreams: u8,
860         device_context_index: u8,
861     ) -> Result<TransferRingControllers> {
862         let pstreams = 1usize << (max_pstreams + 1);
863         let stream_context_array: StreamContextArray = self
864             .mem
865             .read_obj_from_addr(stream_context_array_addr)
866             .map_err(Error::ReadGuestMemory)?;
867         let mut trcs = Vec::new();
868 
869         // Stream ID 0 is reserved (xHCI spec Section 4.12.2)
870         for i in 1..pstreams {
871             let stream_context = &stream_context_array.stream_contexts[i];
872             let context_type = stream_context.get_stream_context_type();
873             if context_type != 1 {
874                 // We only support Linear Stream Context Array for now
875                 return Err(Error::BadStreamContextType(context_type));
876             }
877             let trc = TransferRingController::new(
878                 self.mem.clone(),
879                 self.hub
880                     .get_port(self.port_id.get()?)
881                     .ok_or(Error::GetPort(self.port_id.get()?))?,
882                 self.event_loop.clone(),
883                 self.interrupter.clone(),
884                 self.slot_id,
885                 device_context_index,
886                 Arc::downgrade(self),
887                 Some(i as u16),
888             )
889             .map_err(Error::CreateTransferController)?;
890             trc.set_dequeue_pointer(stream_context.get_tr_dequeue_pointer().get_gpa());
891             trc.set_consumer_cycle_state(stream_context.get_dequeue_cycle_state());
892             trcs.push(trc);
893         }
894         Ok(TransferRingControllers::Stream(trcs))
895     }
896 
add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()>897     fn add_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
898         xhci_trace!(
899             "adding one endpoint, device context index {}",
900             device_context_index
901         );
902         let mut device_context = self.get_device_context()?;
903         let transfer_ring_index = (device_context_index - 1) as usize;
904         let endpoint_context = &mut device_context.endpoint_context[transfer_ring_index];
905         let max_pstreams = endpoint_context.get_max_primary_streams();
906         let tr_dequeue_pointer = endpoint_context.get_tr_dequeue_pointer().get_gpa();
907         let endpoint_context_addr = self
908             .get_device_context_addr()?
909             .unchecked_add(size_of::<SlotContext>() as u64)
910             .unchecked_add(size_of::<EndpointContext>() as u64 * transfer_ring_index as u64);
911         let trcs = if max_pstreams > 0 {
912             if !valid_max_pstreams(max_pstreams) {
913                 return Err(Error::BadEndpointContext(endpoint_context_addr));
914             }
915             let endpoint_type = endpoint_context.get_endpoint_type();
916             if endpoint_type != 2 && endpoint_type != 6 {
917                 // Stream is only supported on a bulk endpoint
918                 return Err(Error::BadEndpointId(transfer_ring_index as u8));
919             }
920             if endpoint_context.get_linear_stream_array() != 1 {
921                 // We only support Linear Stream Context Array for now
922                 return Err(Error::BadEndpointContext(endpoint_context_addr));
923             }
924 
925             let trcs =
926                 self.create_stream_trcs(tr_dequeue_pointer, max_pstreams, device_context_index)?;
927 
928             if let Some(port) = self.hub.get_port(self.port_id.get()?) {
929                 if let Some(backend_device) = port.backend_device().as_mut() {
930                     let mut endpoint_address = device_context_index / 2;
931                     if device_context_index % 2 == 1 {
932                         endpoint_address |= 1u8 << 7;
933                     }
934                     let streams = 1 << (max_pstreams + 1);
935                     // Subtracting 1 is to ignore Stream ID 0
936                     backend_device
937                         .lock()
938                         .alloc_streams(endpoint_address, streams - 1)
939                         .map_err(Error::AllocStreams)?;
940                 }
941             }
942             trcs
943         } else {
944             let trc = TransferRingController::new(
945                 self.mem.clone(),
946                 self.hub
947                     .get_port(self.port_id.get()?)
948                     .ok_or(Error::GetPort(self.port_id.get()?))?,
949                 self.event_loop.clone(),
950                 self.interrupter.clone(),
951                 self.slot_id,
952                 device_context_index,
953                 Arc::downgrade(self),
954                 None,
955             )
956             .map_err(Error::CreateTransferController)?;
957             trc.set_dequeue_pointer(tr_dequeue_pointer);
958             trc.set_consumer_cycle_state(endpoint_context.get_dequeue_cycle_state());
959             TransferRingControllers::Endpoint(trc)
960         };
961         self.set_trcs(transfer_ring_index, Some(trcs));
962         endpoint_context.set_endpoint_state(EndpointState::Running);
963         self.set_device_context(device_context)
964     }
965 
drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()>966     fn drop_one_endpoint(self: &Arc<Self>, device_context_index: u8) -> Result<()> {
967         let endpoint_index = (device_context_index - 1) as usize;
968         let mut device_context = self.get_device_context()?;
969         let endpoint_context = &mut device_context.endpoint_context[endpoint_index];
970         if endpoint_context.get_max_primary_streams() > 0 {
971             if let Some(port) = self.hub.get_port(self.port_id.get()?) {
972                 if let Some(backend_device) = port.backend_device().as_mut() {
973                     let mut endpoint_address = device_context_index / 2;
974                     if device_context_index % 2 == 1 {
975                         endpoint_address |= 1u8 << 7;
976                     }
977                     backend_device
978                         .lock()
979                         .free_streams(endpoint_address)
980                         .map_err(Error::FreeStreams)?;
981                 }
982             }
983         }
984         self.set_trcs(endpoint_index, None);
985         endpoint_context.set_endpoint_state(EndpointState::Disabled);
986         self.set_device_context(device_context)
987     }
988 
get_device_context(&self) -> Result<DeviceContext>989     fn get_device_context(&self) -> Result<DeviceContext> {
990         let ctx = self
991             .mem
992             .read_obj_from_addr(self.get_device_context_addr()?)
993             .map_err(Error::ReadGuestMemory)?;
994         Ok(ctx)
995     }
996 
set_device_context(&self, device_context: DeviceContext) -> Result<()>997     fn set_device_context(&self, device_context: DeviceContext) -> Result<()> {
998         self.mem
999             .write_obj_at_addr(device_context, self.get_device_context_addr()?)
1000             .map_err(Error::WriteGuestMemory)
1001     }
1002 
copy_context( &self, input_context_ptr: GuestAddress, device_context_index: u8, ) -> Result<()>1003     fn copy_context(
1004         &self,
1005         input_context_ptr: GuestAddress,
1006         device_context_index: u8,
1007     ) -> Result<()> {
1008         // Note that it could be slot context or device context. They have the same size. Won't
1009         // make a difference here.
1010         let ctx: EndpointContext = self
1011             .mem
1012             .read_obj_from_addr(
1013                 input_context_ptr
1014                     .checked_add(
1015                         (device_context_index as u64 + 1) * DEVICE_CONTEXT_ENTRY_SIZE as u64,
1016                     )
1017                     .ok_or(Error::BadInputContextAddr(input_context_ptr))?,
1018             )
1019             .map_err(Error::ReadGuestMemory)?;
1020         xhci_trace!("copy_context {:?}", ctx);
1021         let device_context_ptr = self.get_device_context_addr()?;
1022         self.mem
1023             .write_obj_at_addr(
1024                 ctx,
1025                 device_context_ptr
1026                     .checked_add(device_context_index as u64 * DEVICE_CONTEXT_ENTRY_SIZE as u64)
1027                     .ok_or(Error::BadDeviceContextAddr(device_context_ptr))?,
1028             )
1029             .map_err(Error::WriteGuestMemory)
1030     }
1031 
get_device_context_addr(&self) -> Result<GuestAddress>1032     fn get_device_context_addr(&self) -> Result<GuestAddress> {
1033         let addr: u64 = self
1034             .mem
1035             .read_obj_from_addr(GuestAddress(
1036                 self.dcbaap.get_value() + size_of::<u64>() as u64 * self.slot_id as u64,
1037             ))
1038             .map_err(Error::ReadGuestMemory)?;
1039         Ok(GuestAddress(addr))
1040     }
1041 
set_state(&self, state: DeviceSlotState) -> Result<()>1042     fn set_state(&self, state: DeviceSlotState) -> Result<()> {
1043         let mut ctx = self.get_device_context()?;
1044         ctx.slot_context.set_slot_state(state);
1045         self.set_device_context(ctx)
1046     }
1047 
halt_endpoint(&self, endpoint_id: u8) -> Result<()>1048     pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> {
1049         if !valid_endpoint_id(endpoint_id) {
1050             return Err(Error::BadEndpointId(endpoint_id));
1051         }
1052         let index = endpoint_id - 1;
1053         let mut device_context = self.get_device_context()?;
1054         let endpoint_context = &mut device_context.endpoint_context[index as usize];
1055         match self.get_trcs(index as usize) {
1056             Some(trcs) => match trcs {
1057                 TransferRingControllers::Endpoint(trc) => {
1058                     endpoint_context
1059                         .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
1060                     endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state());
1061                 }
1062                 TransferRingControllers::Stream(trcs) => {
1063                     let stream_context_array_addr =
1064                         endpoint_context.get_tr_dequeue_pointer().get_gpa();
1065                     let mut stream_context_array: StreamContextArray = self
1066                         .mem
1067                         .read_obj_from_addr(stream_context_array_addr)
1068                         .map_err(Error::ReadGuestMemory)?;
1069                     for (i, trc) in trcs.iter().enumerate() {
1070                         stream_context_array.stream_contexts[i + 1]
1071                             .set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer()));
1072                         stream_context_array.stream_contexts[i + 1]
1073                             .set_dequeue_cycle_state(trc.get_consumer_cycle_state());
1074                     }
1075                     self.mem
1076                         .write_obj_at_addr(stream_context_array, stream_context_array_addr)
1077                         .map_err(Error::WriteGuestMemory)?;
1078                 }
1079             },
1080             None => {
1081                 error!("trc for endpoint {} not found", endpoint_id);
1082                 return Err(Error::BadEndpointId(endpoint_id));
1083             }
1084         }
1085         endpoint_context.set_endpoint_state(EndpointState::Halted);
1086         self.set_device_context(device_context)?;
1087         Ok(())
1088     }
1089 }
1090