1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/bin/chunneld.rs
17 
18 //! Host-side stream socket forwarder
19 
20 use std::collections::btree_map::Entry as BTreeMapEntry;
21 use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
22 use std::fmt;
23 use std::io;
24 use std::net::{Ipv4Addr, Ipv6Addr, TcpListener};
25 use std::os::unix::io::AsRawFd;
26 use std::result;
27 use std::sync::{Arc, LazyLock, Mutex};
28 use std::time::Duration;
29 
30 use forwarder::forwarder::ForwarderSession;
31 use jni::objects::{JIntArray, JObject, JValue};
32 use jni::sys::jint;
33 use jni::JNIEnv;
34 use log::{debug, error, info, warn};
35 use nix::sys::eventfd::EventFd;
36 use poll_token_derive::PollToken;
37 use vmm_sys_util::poll::{PollContext, PollToken};
38 use vsock::VsockListener;
39 use vsock::VMADDR_CID_ANY;
40 
/// Maximum time to wait for the guest to connect back over vsock after a forwarding request
/// has been issued for a freshly accepted TCP connection.
const CHUNNEL_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

/// Port value used when binding the per-connection vsock listener. Presumably the vsock
/// "any port" wildcard, letting the kernel pick a free port - TODO confirm against the vsock
/// headers.
const VMADDR_PORT_ANY: u32 = u32::MAX;

/// Eventfd that makes the main poll loop return once it becomes readable.
static SHUTDOWN_EVT: LazyLock<EventFd> =
    LazyLock::new(|| EventFd::new().expect("Could not create shutdown eventfd"));

/// Eventfd that wakes the main poll loop after `UPDATE_QUEUE` has been refilled.
static UPDATE_EVT: LazyLock<EventFd> =
    LazyLock::new(|| EventFd::new().expect("Could not create update eventfd"));

/// Most recently supplied list of ports to listen on, waiting to be picked up by the main
/// poll loop.
static UPDATE_QUEUE: LazyLock<Arc<Mutex<VecDeque<u16>>>> =
    LazyLock::new(|| Arc::new(Mutex::new(VecDeque::new())));
53 
/// Errors that can occur while setting up or running the host-side forwarder.
#[remain::sorted]
#[derive(Debug)]
enum Error {
    /// Binding the per-connection vsock listener failed.
    BindVsock(io::Error),
    /// A vsock connection arrived from a CID other than the expected one.
    IncorrectCid(u32),
    /// The JNI callback that asks for the guest-side forwarder to be launched failed.
    LaunchForwarderGuest(jni::errors::Error),
    /// A listener event was received for a port that has no listeners registered.
    NoListenerForPort(u16),
    /// A socket event was received for a tag that has no forwarding session.
    NoSessionForTag(SessionTag),
    /// Adding an fd to a poll context failed.
    PollContextAdd(vmm_sys_util::errno::Error),
    /// Removing an fd from a poll context failed.
    PollContextDelete(vmm_sys_util::errno::Error),
    /// Creating a poll context failed.
    PollContextNew(vmm_sys_util::errno::Error),
    /// Waiting on a poll context failed.
    PollWait(vmm_sys_util::errno::Error),
    /// Switching the vsock listener to non-blocking mode failed.
    SetVsockNonblocking(io::Error),
    /// Accepting the pending TCP connection failed.
    TcpAccept(io::Error),
    /// Reading the local address of the TCP listener failed.
    TcpListenerPort(io::Error),
    /// Reading (draining) the update eventfd failed.
    UpdateEventRead(nix::Error),
    /// Accepting the guest's vsock connection failed.
    VsockAccept(io::Error),
    /// The guest did not connect over vsock within `CHUNNEL_CONNECT_TIMEOUT`.
    VsockAcceptTimeout,
    /// Reading the local address of the vsock listener failed.
    VsockListenerPort(io::Error),
}
74 
/// Result type used throughout this file, with this file's `Error` as the error type.
type Result<T> = result::Result<T, Error>;
76 
/// Renders each error as a single-line, human-readable message; wrapped errors are appended
/// using their own `Display` output.
impl fmt::Display for Error {
    #[remain::check]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use self::Error::*;

        // The arms below are listed in the same alphabetical order as the enum's variants.
        #[remain::sorted]
        match self {
            BindVsock(e) => write!(f, "failed to bind vsock: {}", e),
            IncorrectCid(cid) => write!(f, "chunnel connection from unexpected cid {}", cid),
            LaunchForwarderGuest(e) => write!(f, "failed to launch forwarder_guest {}", e),
            NoListenerForPort(port) => write!(f, "could not find listener for port: {}", port),
            // Session tags are printed in hexadecimal.
            NoSessionForTag(tag) => write!(f, "could not find session for tag: {:x}", tag),
            PollContextAdd(e) => write!(f, "failed to add fd to poll context: {}", e),
            PollContextDelete(e) => write!(f, "failed to delete fd from poll context: {}", e),
            PollContextNew(e) => write!(f, "failed to create poll context: {}", e),
            PollWait(e) => write!(f, "failed to wait for poll: {}", e),
            SetVsockNonblocking(e) => write!(f, "failed to set vsock to nonblocking: {}", e),
            TcpAccept(e) => write!(f, "failed to accept tcp: {}", e),
            TcpListenerPort(e) => {
                write!(f, "failed to read local sockaddr for tcp listener: {}", e)
            }
            UpdateEventRead(e) => write!(f, "failed to read update eventfd: {}", e),
            VsockAccept(e) => write!(f, "failed to accept vsock: {}", e),
            VsockAcceptTimeout => write!(f, "timed out waiting for vsock connection"),
            VsockListenerPort(e) => write!(f, "failed to get vsock listener port: {}", e),
        }
    }
}
105 
/// A tag that uniquely identifies a particular forwarding session. This has arbitrarily been
/// chosen as the raw fd of the session's local (TCP) stream, cast to `u32`. The same tag is
/// used for both the local and the remote poll token of a session.
type SessionTag = u32;
109 
/// Poll tokens for the main poll loop; each variant identifies the kind of fd that became
/// readable.
#[derive(Clone, Copy, PollToken)]
enum Token {
    /// The shutdown eventfd was signalled.
    Shutdown,
    /// The update eventfd was signalled; a new port list is waiting in `UPDATE_QUEUE`.
    UpdatePorts,
    /// The IPv4 listener for the given port is readable.
    Ipv4Listener(u16),
    /// The IPv6 listener for the given port is readable.
    Ipv6Listener(u16),
    /// The local stream of the session with the given tag is readable.
    LocalSocket(SessionTag),
    /// The remote stream of the session with the given tag is readable.
    RemoteSocket(SessionTag),
}
120 
/// PortListeners holds both TCP listeners (IPv4 and IPv6) for a given port. Dropping it
/// drops, and thereby closes, both listeners.
struct PortListeners {
    /// Listener bound to the IPv4 loopback address.
    tcp4_listener: TcpListener,
    /// Listener bound to the IPv6 loopback address.
    tcp6_listener: TcpListener,
}
127 
/// SocketFamily specifies whether a socket uses IPv4 or IPv6. It selects which of a port's
/// two listeners a pending connection is accepted from.
enum SocketFamily {
    /// Use the port's IPv4 listener.
    Ipv4,
    /// Use the port's IPv6 listener.
    Ipv6,
}
133 
/// ForwarderSessions encapsulates all forwarding state for the host-side forwarder.
struct ForwarderSessions<'a> {
    /// Listeners currently bound, keyed by TCP port.
    listening_ports: BTreeMap<u16, PortListeners>,
    /// Active forwarding sessions keyed by session tag. Despite the name, sessions accepted
    /// from both the IPv4 and the IPv6 listeners are stored here.
    tcp4_forwarders: HashMap<SessionTag, ForwarderSession>,
    /// CID that incoming vsock connections are required to originate from.
    cid: u32,
    /// JNI environment used to invoke methods on `jni_cb`.
    jni_env: JNIEnv<'a>,
    /// Java callback object on which `onForwardingRequestReceived` is invoked for each
    /// accepted TCP connection.
    jni_cb: JObject<'a>,
}
142 
143 impl<'a> ForwarderSessions<'a> {
144     /// Creates a new instance of ForwarderSessions.
new(cid: i32, jni_env: JNIEnv<'a>, jni_cb: JObject<'a>) -> Result<Self>145     fn new(cid: i32, jni_env: JNIEnv<'a>, jni_cb: JObject<'a>) -> Result<Self> {
146         Ok(ForwarderSessions {
147             listening_ports: BTreeMap::new(),
148             tcp4_forwarders: HashMap::new(),
149             cid: cid as u32,
150             jni_env,
151             jni_cb,
152         })
153     }
154 
155     /// Adds or removes listeners based on the latest listening ports from the D-Bus thread.
process_update_queue(&mut self, poll_ctx: &PollContext<Token>) -> Result<()>156     fn process_update_queue(&mut self, poll_ctx: &PollContext<Token>) -> Result<()> {
157         // Unwrap of LockResult is customary.
158         let mut update_queue = UPDATE_QUEUE.lock().unwrap();
159         let mut active_ports: BTreeSet<u16> = BTreeSet::new();
160 
161         // Add any new listeners first.
162         while let Some(port) = update_queue.pop_front() {
163             // Ignore privileged ports.
164             if port < 1024 {
165                 continue;
166             }
167             if let BTreeMapEntry::Vacant(o) = self.listening_ports.entry(port) {
168                 // Failing to bind a port is not fatal, but we should log it.
169                 // Both IPv4 and IPv6 localhost must be bound since the host may resolve
170                 // "localhost" to either.
171                 let tcp4_listener = match TcpListener::bind((Ipv4Addr::LOCALHOST, port)) {
172                     Ok(listener) => listener,
173                     Err(e) => {
174                         warn!("failed to bind TCPv4 port: {}", e);
175                         continue;
176                     }
177                 };
178                 let tcp6_listener = match TcpListener::bind((Ipv6Addr::LOCALHOST, port)) {
179                     Ok(listener) => listener,
180                     Err(e) => {
181                         warn!("failed to bind TCPv6 port: {}", e);
182                         continue;
183                     }
184                 };
185                 poll_ctx
186                     .add(&tcp4_listener, Token::Ipv4Listener(port))
187                     .map_err(Error::PollContextAdd)?;
188                 poll_ctx
189                     .add(&tcp6_listener, Token::Ipv6Listener(port))
190                     .map_err(Error::PollContextAdd)?;
191                 o.insert(PortListeners { tcp4_listener, tcp6_listener });
192             }
193             active_ports.insert(port);
194         }
195 
196         // Iterate over the existing listeners; if the port is no longer in the
197         // listener list, remove it.
198         let old_ports: Vec<u16> = self.listening_ports.keys().cloned().collect();
199         for port in old_ports.iter() {
200             if !active_ports.contains(port) {
201                 // Remove the PortListeners struct first - on error we want to drop it and the
202                 // fds it contains.
203                 let _listening_port = self.listening_ports.remove(port);
204             }
205         }
206 
207         // Consume the eventfd.
208         UPDATE_EVT.read().map_err(Error::UpdateEventRead)?;
209 
210         Ok(())
211     }
212 
accept_connection( &mut self, poll_ctx: &PollContext<Token>, port: u16, sock_family: SocketFamily, ) -> Result<()>213     fn accept_connection(
214         &mut self,
215         poll_ctx: &PollContext<Token>,
216         port: u16,
217         sock_family: SocketFamily,
218     ) -> Result<()> {
219         let port_listeners =
220             self.listening_ports.get(&port).ok_or(Error::NoListenerForPort(port))?;
221 
222         let listener = match sock_family {
223             SocketFamily::Ipv4 => &port_listeners.tcp4_listener,
224             SocketFamily::Ipv6 => &port_listeners.tcp6_listener,
225         };
226 
227         // This session should be dropped if any of the PollContext setup fails. Since the only
228         // extant fds for the underlying sockets will be closed, they will be unregistered from
229         // epoll set automatically.
230         let session =
231             create_forwarder_session(listener, self.cid, &mut self.jni_env, &self.jni_cb)?;
232 
233         let tag = session.local_stream().as_raw_fd() as u32;
234 
235         poll_ctx
236             .add(session.local_stream(), Token::LocalSocket(tag))
237             .map_err(Error::PollContextAdd)?;
238         poll_ctx
239             .add(session.remote_stream(), Token::RemoteSocket(tag))
240             .map_err(Error::PollContextAdd)?;
241 
242         self.tcp4_forwarders.insert(tag, session);
243 
244         Ok(())
245     }
246 
forward_from_local(&mut self, poll_ctx: &PollContext<Token>, tag: SessionTag) -> Result<()>247     fn forward_from_local(&mut self, poll_ctx: &PollContext<Token>, tag: SessionTag) -> Result<()> {
248         let session = self.tcp4_forwarders.get_mut(&tag).ok_or(Error::NoSessionForTag(tag))?;
249         let shutdown = session.forward_from_local().unwrap_or(true);
250         if shutdown {
251             poll_ctx.delete(session.local_stream()).map_err(Error::PollContextDelete)?;
252             if session.is_shut_down() {
253                 self.tcp4_forwarders.remove(&tag);
254             }
255         }
256 
257         Ok(())
258     }
259 
forward_from_remote( &mut self, poll_ctx: &PollContext<Token>, tag: SessionTag, ) -> Result<()>260     fn forward_from_remote(
261         &mut self,
262         poll_ctx: &PollContext<Token>,
263         tag: SessionTag,
264     ) -> Result<()> {
265         let session = self.tcp4_forwarders.get_mut(&tag).ok_or(Error::NoSessionForTag(tag))?;
266         let shutdown = session.forward_from_remote().unwrap_or(true);
267         if shutdown {
268             poll_ctx.delete(session.remote_stream()).map_err(Error::PollContextDelete)?;
269             if session.is_shut_down() {
270                 self.tcp4_forwarders.remove(&tag);
271             }
272         }
273 
274         Ok(())
275     }
276 
run(&mut self) -> Result<()>277     fn run(&mut self) -> Result<()> {
278         let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
279         poll_ctx.add(&*UPDATE_EVT, Token::UpdatePorts).map_err(Error::PollContextAdd)?;
280         poll_ctx.add(&*SHUTDOWN_EVT, Token::Shutdown).map_err(Error::PollContextAdd)?;
281 
282         loop {
283             let events = poll_ctx.wait().map_err(Error::PollWait)?;
284 
285             for event in events.iter_readable() {
286                 match event.token() {
287                     Token::Shutdown => {
288                         return Ok(());
289                     }
290                     Token::UpdatePorts => {
291                         if let Err(e) = self.process_update_queue(&poll_ctx) {
292                             error!("error updating listening ports: {}", e);
293                         }
294                     }
295                     Token::Ipv4Listener(port) => {
296                         if let Err(e) = self.accept_connection(&poll_ctx, port, SocketFamily::Ipv4)
297                         {
298                             error!("error accepting connection: {}", e);
299                         }
300                     }
301                     Token::Ipv6Listener(port) => {
302                         if let Err(e) = self.accept_connection(&poll_ctx, port, SocketFamily::Ipv6)
303                         {
304                             error!("error accepting connection: {}", e);
305                         }
306                     }
307                     Token::LocalSocket(tag) => {
308                         if let Err(e) = self.forward_from_local(&poll_ctx, tag) {
309                             error!("error forwarding local traffic: {}", e);
310                         }
311                     }
312                     Token::RemoteSocket(tag) => {
313                         if let Err(e) = self.forward_from_remote(&poll_ctx, tag) {
314                             error!("error forwarding remote traffic: {}", e);
315                         }
316                     }
317                 }
318             }
319         }
320     }
321 }
322 
323 /// Creates a forwarder session from a `listener` that has a pending connection to accept.
create_forwarder_session( listener: &TcpListener, cid: u32, jni_env: &mut JNIEnv, jni_cb: &JObject, ) -> Result<ForwarderSession>324 fn create_forwarder_session(
325     listener: &TcpListener,
326     cid: u32,
327     jni_env: &mut JNIEnv,
328     jni_cb: &JObject,
329 ) -> Result<ForwarderSession> {
330     let (tcp_stream, _) = listener.accept().map_err(Error::TcpAccept)?;
331     // Bind a vsock port, tell the guest to connect, and accept the connection.
332     let vsock_listener = VsockListener::bind_with_cid_port(VMADDR_CID_ANY, VMADDR_PORT_ANY)
333         .map_err(Error::BindVsock)?;
334     vsock_listener.set_nonblocking(true).map_err(Error::SetVsockNonblocking)?;
335 
336     let tcp4_port = listener.local_addr().map_err(Error::TcpListenerPort)?.port();
337     let vsock_port = vsock_listener.local_addr().map_err(Error::VsockListenerPort)?.port();
338     jni_env
339         .call_method(
340             jni_cb,
341             "onForwardingRequestReceived",
342             "(II)V",
343             &[JValue::Int(tcp4_port.into()), JValue::Int(vsock_port as i32)],
344         )
345         .map_err(Error::LaunchForwarderGuest)?;
346 
347     #[derive(PollToken)]
348     enum Token {
349         VsockAccept,
350     }
351 
352     let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
353     poll_ctx.add(&vsock_listener, Token::VsockAccept).map_err(Error::PollContextAdd)?;
354 
355     // Wait a few seconds for the guest to connect.
356     let events = poll_ctx.wait_timeout(CHUNNEL_CONNECT_TIMEOUT).map_err(Error::PollWait)?;
357 
358     match events.iter_readable().next() {
359         Some(_) => {
360             let (vsock_stream, sockaddr) = vsock_listener.accept().map_err(Error::VsockAccept)?;
361 
362             if sockaddr.cid() != cid {
363                 Err(Error::IncorrectCid(sockaddr.cid()))
364             } else {
365                 Ok(ForwarderSession::new(tcp_stream.into(), vsock_stream.into()))
366             }
367         }
368         None => Err(Error::VsockAcceptTimeout),
369     }
370 }
371 
372 // TODO(b/340126051): Host can receive opened ports from the guest.
run_forwarder_host(cid: i32, jni_env: JNIEnv, jni_cb: JObject) -> Result<()>373 fn run_forwarder_host(cid: i32, jni_env: JNIEnv, jni_cb: JObject) -> Result<()> {
374     debug!("Starting forwarder_host");
375     let mut sessions = ForwarderSessions::new(cid, jni_env, jni_cb)?;
376     sessions.run()
377 }
378 
379 /// JNI function for running forwarder_host.
380 #[no_mangle]
Java_com_android_virtualization_terminal_DebianServiceImpl_runForwarderHost( env: JNIEnv, _class: JObject, cid: jint, callback: JObject, )381 pub extern "C" fn Java_com_android_virtualization_terminal_DebianServiceImpl_runForwarderHost(
382     env: JNIEnv,
383     _class: JObject,
384     cid: jint,
385     callback: JObject,
386 ) {
387     // Clear shutdown event FD before running forwarder host.
388     SHUTDOWN_EVT.write(1).expect("Failed to write shutdown event FD");
389     SHUTDOWN_EVT.read().expect("Failed to consume shutdown event FD");
390 
391     match run_forwarder_host(cid, env, callback) {
392         Ok(_) => {
393             info!("forwarder_host is terminated");
394         }
395         Err(e) => {
396             error!("Error on forwarder_host: {:?}", e);
397         }
398     }
399 }
400 
/// JNI function for terminating forwarder_host.
///
/// Signals the shutdown eventfd; the main poll loop returns when it sees that eventfd become
/// readable. Panics if the eventfd cannot be written.
#[no_mangle]
pub extern "C" fn Java_com_android_virtualization_terminal_DebianServiceImpl_terminateForwarderHost(
    _env: JNIEnv,
    _class: JObject,
) {
    SHUTDOWN_EVT.write(1).expect("Failed to write shutdown event FD");
}
409 
410 /// JNI function for updating listening ports.
411 #[no_mangle]
Java_com_android_virtualization_terminal_DebianServiceImpl_updateListeningPorts( env: JNIEnv, _class: JObject, ports: JIntArray, )412 pub extern "C" fn Java_com_android_virtualization_terminal_DebianServiceImpl_updateListeningPorts(
413     env: JNIEnv,
414     _class: JObject,
415     ports: JIntArray,
416 ) {
417     let length = env.get_array_length(&ports).expect("Failed to get length of port array");
418     let mut buf = vec![0; length as usize];
419     env.get_int_array_region(ports, 0, &mut buf).expect("Failed to get port array");
420 
421     let mut update_queue = UPDATE_QUEUE.lock().unwrap();
422     update_queue.clear();
423     for port in buf {
424         update_queue.push_back(port.try_into().expect("Failed to add port into update queue"));
425     }
426     UPDATE_EVT.write(1).expect("failed to write update eventfd");
427 }
428