1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/bin/chunneld.rs
17
18 //! Host-side stream socket forwarder
19
20 use std::collections::btree_map::Entry as BTreeMapEntry;
21 use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque};
22 use std::fmt;
23 use std::io;
24 use std::net::{Ipv4Addr, Ipv6Addr, TcpListener};
25 use std::os::unix::io::AsRawFd;
26 use std::result;
27 use std::sync::{Arc, LazyLock, Mutex};
28 use std::time::Duration;
29
30 use forwarder::forwarder::ForwarderSession;
31 use jni::objects::{JIntArray, JObject, JValue};
32 use jni::sys::jint;
33 use jni::JNIEnv;
34 use log::{debug, error, info, warn};
35 use nix::sys::eventfd::EventFd;
36 use poll_token_derive::PollToken;
37 use vmm_sys_util::poll::{PollContext, PollToken};
38 use vsock::VsockListener;
39 use vsock::VMADDR_CID_ANY;
40
41 const CHUNNEL_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
42
43 const VMADDR_PORT_ANY: u32 = u32::MAX;
44
45 static SHUTDOWN_EVT: LazyLock<EventFd> =
46 LazyLock::new(|| EventFd::new().expect("Could not create shutdown eventfd"));
47
48 static UPDATE_EVT: LazyLock<EventFd> =
49 LazyLock::new(|| EventFd::new().expect("Could not create update eventfd"));
50
51 static UPDATE_QUEUE: LazyLock<Arc<Mutex<VecDeque<u16>>>> =
52 LazyLock::new(|| Arc::new(Mutex::new(VecDeque::new())));
53
/// Errors produced while running the host-side forwarder.
#[remain::sorted]
#[derive(Debug)]
enum Error {
    /// Binding the per-connection vsock listener failed.
    BindVsock(io::Error),
    /// The vsock peer that connected back is not the expected guest CID.
    IncorrectCid(u32),
    /// Calling `onForwardingRequestReceived` on the Java callback failed.
    LaunchForwarderGuest(jni::errors::Error),
    /// A listener event arrived for a port that has no `PortListeners` entry.
    NoListenerForPort(u16),
    /// A socket event arrived for a session tag that is not tracked.
    NoSessionForTag(SessionTag),
    /// Registering an fd with a poll context failed.
    PollContextAdd(vmm_sys_util::errno::Error),
    /// Deregistering an fd from a poll context failed.
    PollContextDelete(vmm_sys_util::errno::Error),
    /// Creating a poll context failed.
    PollContextNew(vmm_sys_util::errno::Error),
    /// Waiting on a poll context failed.
    PollWait(vmm_sys_util::errno::Error),
    /// Switching the vsock listener to non-blocking mode failed.
    SetVsockNonblocking(io::Error),
    /// Accepting the incoming TCP connection failed.
    TcpAccept(io::Error),
    /// Reading the TCP listener's local address failed.
    TcpListenerPort(io::Error),
    /// Reading `UPDATE_EVT` failed.
    UpdateEventRead(nix::Error),
    /// Accepting the guest's vsock connection failed.
    VsockAccept(io::Error),
    /// The guest did not connect back within `CHUNNEL_CONNECT_TIMEOUT`.
    VsockAcceptTimeout,
    /// Reading the vsock listener's local address failed.
    VsockListenerPort(io::Error),
}

/// Result type used throughout the forwarder.
type Result<T> = result::Result<T, Error>;
76
77 impl fmt::Display for Error {
78 #[remain::check]
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result79 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80 use self::Error::*;
81
82 #[remain::sorted]
83 match self {
84 BindVsock(e) => write!(f, "failed to bind vsock: {}", e),
85 IncorrectCid(cid) => write!(f, "chunnel connection from unexpected cid {}", cid),
86 LaunchForwarderGuest(e) => write!(f, "failed to launch forwarder_guest {}", e),
87 NoListenerForPort(port) => write!(f, "could not find listener for port: {}", port),
88 NoSessionForTag(tag) => write!(f, "could not find session for tag: {:x}", tag),
89 PollContextAdd(e) => write!(f, "failed to add fd to poll context: {}", e),
90 PollContextDelete(e) => write!(f, "failed to delete fd from poll context: {}", e),
91 PollContextNew(e) => write!(f, "failed to create poll context: {}", e),
92 PollWait(e) => write!(f, "failed to wait for poll: {}", e),
93 SetVsockNonblocking(e) => write!(f, "failed to set vsock to nonblocking: {}", e),
94 TcpAccept(e) => write!(f, "failed to accept tcp: {}", e),
95 TcpListenerPort(e) => {
96 write!(f, "failed to read local sockaddr for tcp listener: {}", e)
97 }
98 UpdateEventRead(e) => write!(f, "failed to read update eventfd: {}", e),
99 VsockAccept(e) => write!(f, "failed to accept vsock: {}", e),
100 VsockAcceptTimeout => write!(f, "timed out waiting for vsock connection"),
101 VsockListenerPort(e) => write!(f, "failed to get vsock listener port: {}", e),
102 }
103 }
104 }
105
/// A tag that uniquely identifies a particular forwarding session. This has arbitrarily been
/// chosen as the fd of the local (TCP) socket.
///
/// Because fds are reused once closed, a tag is only unique among sessions that are still alive.
type SessionTag = u32;
109
/// Implements PollToken for chunneld's main poll loop.
#[derive(Clone, Copy, PollToken)]
enum Token {
    /// `SHUTDOWN_EVT` became readable; the loop should return.
    Shutdown,
    /// `UPDATE_EVT` became readable; `UPDATE_QUEUE` holds a new port set.
    UpdatePorts,
    /// The IPv4 loopback listener for the given port has a pending connection.
    Ipv4Listener(u16),
    /// The IPv6 loopback listener for the given port has a pending connection.
    Ipv6Listener(u16),
    /// The local (TCP) stream of the tagged session is readable.
    LocalSocket(SessionTag),
    /// The remote (vsock) stream of the tagged session is readable.
    RemoteSocket(SessionTag),
}
120
/// PortListeners holds the loopback listeners (IPv4 and IPv6) bound for a single forwarded port.
///
/// Dropping it closes both listener fds.
struct PortListeners {
    /// Listener bound to `127.0.0.1:<port>`.
    tcp4_listener: TcpListener,
    /// Listener bound to `[::1]:<port>`.
    tcp6_listener: TcpListener,
}
127
/// SocketFamily specifies whether a socket uses IPv4 or IPv6.
///
/// A fieldless value type, so it is `Copy` and comparable; `Debug` makes it loggable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SocketFamily {
    /// Selects the IPv4 loopback listener.
    Ipv4,
    /// Selects the IPv6 loopback listener.
    Ipv6,
}
133
/// ForwarderSessions encapsulates all forwarding state for chunneld.
struct ForwarderSessions<'a> {
    // Loopback listeners keyed by forwarded port.
    listening_ports: BTreeMap<u16, PortListeners>,
    // Active sessions keyed by tag. Despite the name, sessions accepted on both the IPv4 and
    // the IPv6 listeners are stored here.
    tcp4_forwarders: HashMap<SessionTag, ForwarderSession>,
    // Guest vsock CID; vsock connections from any other CID are rejected.
    cid: u32,
    // JNI environment used to invoke the Java callback.
    jni_env: JNIEnv<'a>,
    // Java object on which `onForwardingRequestReceived(int, int)` is called.
    jni_cb: JObject<'a>,
}
142
143 impl<'a> ForwarderSessions<'a> {
144 /// Creates a new instance of ForwarderSessions.
new(cid: i32, jni_env: JNIEnv<'a>, jni_cb: JObject<'a>) -> Result<Self>145 fn new(cid: i32, jni_env: JNIEnv<'a>, jni_cb: JObject<'a>) -> Result<Self> {
146 Ok(ForwarderSessions {
147 listening_ports: BTreeMap::new(),
148 tcp4_forwarders: HashMap::new(),
149 cid: cid as u32,
150 jni_env,
151 jni_cb,
152 })
153 }
154
155 /// Adds or removes listeners based on the latest listening ports from the D-Bus thread.
process_update_queue(&mut self, poll_ctx: &PollContext<Token>) -> Result<()>156 fn process_update_queue(&mut self, poll_ctx: &PollContext<Token>) -> Result<()> {
157 // Unwrap of LockResult is customary.
158 let mut update_queue = UPDATE_QUEUE.lock().unwrap();
159 let mut active_ports: BTreeSet<u16> = BTreeSet::new();
160
161 // Add any new listeners first.
162 while let Some(port) = update_queue.pop_front() {
163 // Ignore privileged ports.
164 if port < 1024 {
165 continue;
166 }
167 if let BTreeMapEntry::Vacant(o) = self.listening_ports.entry(port) {
168 // Failing to bind a port is not fatal, but we should log it.
169 // Both IPv4 and IPv6 localhost must be bound since the host may resolve
170 // "localhost" to either.
171 let tcp4_listener = match TcpListener::bind((Ipv4Addr::LOCALHOST, port)) {
172 Ok(listener) => listener,
173 Err(e) => {
174 warn!("failed to bind TCPv4 port: {}", e);
175 continue;
176 }
177 };
178 let tcp6_listener = match TcpListener::bind((Ipv6Addr::LOCALHOST, port)) {
179 Ok(listener) => listener,
180 Err(e) => {
181 warn!("failed to bind TCPv6 port: {}", e);
182 continue;
183 }
184 };
185 poll_ctx
186 .add(&tcp4_listener, Token::Ipv4Listener(port))
187 .map_err(Error::PollContextAdd)?;
188 poll_ctx
189 .add(&tcp6_listener, Token::Ipv6Listener(port))
190 .map_err(Error::PollContextAdd)?;
191 o.insert(PortListeners { tcp4_listener, tcp6_listener });
192 }
193 active_ports.insert(port);
194 }
195
196 // Iterate over the existing listeners; if the port is no longer in the
197 // listener list, remove it.
198 let old_ports: Vec<u16> = self.listening_ports.keys().cloned().collect();
199 for port in old_ports.iter() {
200 if !active_ports.contains(port) {
201 // Remove the PortListeners struct first - on error we want to drop it and the
202 // fds it contains.
203 let _listening_port = self.listening_ports.remove(port);
204 }
205 }
206
207 // Consume the eventfd.
208 UPDATE_EVT.read().map_err(Error::UpdateEventRead)?;
209
210 Ok(())
211 }
212
accept_connection( &mut self, poll_ctx: &PollContext<Token>, port: u16, sock_family: SocketFamily, ) -> Result<()>213 fn accept_connection(
214 &mut self,
215 poll_ctx: &PollContext<Token>,
216 port: u16,
217 sock_family: SocketFamily,
218 ) -> Result<()> {
219 let port_listeners =
220 self.listening_ports.get(&port).ok_or(Error::NoListenerForPort(port))?;
221
222 let listener = match sock_family {
223 SocketFamily::Ipv4 => &port_listeners.tcp4_listener,
224 SocketFamily::Ipv6 => &port_listeners.tcp6_listener,
225 };
226
227 // This session should be dropped if any of the PollContext setup fails. Since the only
228 // extant fds for the underlying sockets will be closed, they will be unregistered from
229 // epoll set automatically.
230 let session =
231 create_forwarder_session(listener, self.cid, &mut self.jni_env, &self.jni_cb)?;
232
233 let tag = session.local_stream().as_raw_fd() as u32;
234
235 poll_ctx
236 .add(session.local_stream(), Token::LocalSocket(tag))
237 .map_err(Error::PollContextAdd)?;
238 poll_ctx
239 .add(session.remote_stream(), Token::RemoteSocket(tag))
240 .map_err(Error::PollContextAdd)?;
241
242 self.tcp4_forwarders.insert(tag, session);
243
244 Ok(())
245 }
246
forward_from_local(&mut self, poll_ctx: &PollContext<Token>, tag: SessionTag) -> Result<()>247 fn forward_from_local(&mut self, poll_ctx: &PollContext<Token>, tag: SessionTag) -> Result<()> {
248 let session = self.tcp4_forwarders.get_mut(&tag).ok_or(Error::NoSessionForTag(tag))?;
249 let shutdown = session.forward_from_local().unwrap_or(true);
250 if shutdown {
251 poll_ctx.delete(session.local_stream()).map_err(Error::PollContextDelete)?;
252 if session.is_shut_down() {
253 self.tcp4_forwarders.remove(&tag);
254 }
255 }
256
257 Ok(())
258 }
259
forward_from_remote( &mut self, poll_ctx: &PollContext<Token>, tag: SessionTag, ) -> Result<()>260 fn forward_from_remote(
261 &mut self,
262 poll_ctx: &PollContext<Token>,
263 tag: SessionTag,
264 ) -> Result<()> {
265 let session = self.tcp4_forwarders.get_mut(&tag).ok_or(Error::NoSessionForTag(tag))?;
266 let shutdown = session.forward_from_remote().unwrap_or(true);
267 if shutdown {
268 poll_ctx.delete(session.remote_stream()).map_err(Error::PollContextDelete)?;
269 if session.is_shut_down() {
270 self.tcp4_forwarders.remove(&tag);
271 }
272 }
273
274 Ok(())
275 }
276
run(&mut self) -> Result<()>277 fn run(&mut self) -> Result<()> {
278 let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
279 poll_ctx.add(&*UPDATE_EVT, Token::UpdatePorts).map_err(Error::PollContextAdd)?;
280 poll_ctx.add(&*SHUTDOWN_EVT, Token::Shutdown).map_err(Error::PollContextAdd)?;
281
282 loop {
283 let events = poll_ctx.wait().map_err(Error::PollWait)?;
284
285 for event in events.iter_readable() {
286 match event.token() {
287 Token::Shutdown => {
288 return Ok(());
289 }
290 Token::UpdatePorts => {
291 if let Err(e) = self.process_update_queue(&poll_ctx) {
292 error!("error updating listening ports: {}", e);
293 }
294 }
295 Token::Ipv4Listener(port) => {
296 if let Err(e) = self.accept_connection(&poll_ctx, port, SocketFamily::Ipv4)
297 {
298 error!("error accepting connection: {}", e);
299 }
300 }
301 Token::Ipv6Listener(port) => {
302 if let Err(e) = self.accept_connection(&poll_ctx, port, SocketFamily::Ipv6)
303 {
304 error!("error accepting connection: {}", e);
305 }
306 }
307 Token::LocalSocket(tag) => {
308 if let Err(e) = self.forward_from_local(&poll_ctx, tag) {
309 error!("error forwarding local traffic: {}", e);
310 }
311 }
312 Token::RemoteSocket(tag) => {
313 if let Err(e) = self.forward_from_remote(&poll_ctx, tag) {
314 error!("error forwarding remote traffic: {}", e);
315 }
316 }
317 }
318 }
319 }
320 }
321 }
322
323 /// Creates a forwarder session from a `listener` that has a pending connection to accept.
create_forwarder_session( listener: &TcpListener, cid: u32, jni_env: &mut JNIEnv, jni_cb: &JObject, ) -> Result<ForwarderSession>324 fn create_forwarder_session(
325 listener: &TcpListener,
326 cid: u32,
327 jni_env: &mut JNIEnv,
328 jni_cb: &JObject,
329 ) -> Result<ForwarderSession> {
330 let (tcp_stream, _) = listener.accept().map_err(Error::TcpAccept)?;
331 // Bind a vsock port, tell the guest to connect, and accept the connection.
332 let vsock_listener = VsockListener::bind_with_cid_port(VMADDR_CID_ANY, VMADDR_PORT_ANY)
333 .map_err(Error::BindVsock)?;
334 vsock_listener.set_nonblocking(true).map_err(Error::SetVsockNonblocking)?;
335
336 let tcp4_port = listener.local_addr().map_err(Error::TcpListenerPort)?.port();
337 let vsock_port = vsock_listener.local_addr().map_err(Error::VsockListenerPort)?.port();
338 jni_env
339 .call_method(
340 jni_cb,
341 "onForwardingRequestReceived",
342 "(II)V",
343 &[JValue::Int(tcp4_port.into()), JValue::Int(vsock_port as i32)],
344 )
345 .map_err(Error::LaunchForwarderGuest)?;
346
347 #[derive(PollToken)]
348 enum Token {
349 VsockAccept,
350 }
351
352 let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
353 poll_ctx.add(&vsock_listener, Token::VsockAccept).map_err(Error::PollContextAdd)?;
354
355 // Wait a few seconds for the guest to connect.
356 let events = poll_ctx.wait_timeout(CHUNNEL_CONNECT_TIMEOUT).map_err(Error::PollWait)?;
357
358 match events.iter_readable().next() {
359 Some(_) => {
360 let (vsock_stream, sockaddr) = vsock_listener.accept().map_err(Error::VsockAccept)?;
361
362 if sockaddr.cid() != cid {
363 Err(Error::IncorrectCid(sockaddr.cid()))
364 } else {
365 Ok(ForwarderSession::new(tcp_stream.into(), vsock_stream.into()))
366 }
367 }
368 None => Err(Error::VsockAcceptTimeout),
369 }
370 }
371
372 // TODO(b/340126051): Host can receive opened ports from the guest.
run_forwarder_host(cid: i32, jni_env: JNIEnv, jni_cb: JObject) -> Result<()>373 fn run_forwarder_host(cid: i32, jni_env: JNIEnv, jni_cb: JObject) -> Result<()> {
374 debug!("Starting forwarder_host");
375 let mut sessions = ForwarderSessions::new(cid, jni_env, jni_cb)?;
376 sessions.run()
377 }
378
379 /// JNI function for running forwarder_host.
380 #[no_mangle]
Java_com_android_virtualization_terminal_DebianServiceImpl_runForwarderHost( env: JNIEnv, _class: JObject, cid: jint, callback: JObject, )381 pub extern "C" fn Java_com_android_virtualization_terminal_DebianServiceImpl_runForwarderHost(
382 env: JNIEnv,
383 _class: JObject,
384 cid: jint,
385 callback: JObject,
386 ) {
387 // Clear shutdown event FD before running forwarder host.
388 SHUTDOWN_EVT.write(1).expect("Failed to write shutdown event FD");
389 SHUTDOWN_EVT.read().expect("Failed to consume shutdown event FD");
390
391 match run_forwarder_host(cid, env, callback) {
392 Ok(_) => {
393 info!("forwarder_host is terminated");
394 }
395 Err(e) => {
396 error!("Error on forwarder_host: {:?}", e);
397 }
398 }
399 }
400
401 /// JNI function for terminating forwarder_host.
402 #[no_mangle]
Java_com_android_virtualization_terminal_DebianServiceImpl_terminateForwarderHost( _env: JNIEnv, _class: JObject, )403 pub extern "C" fn Java_com_android_virtualization_terminal_DebianServiceImpl_terminateForwarderHost(
404 _env: JNIEnv,
405 _class: JObject,
406 ) {
407 SHUTDOWN_EVT.write(1).expect("Failed to write shutdown event FD");
408 }
409
410 /// JNI function for updating listening ports.
411 #[no_mangle]
Java_com_android_virtualization_terminal_DebianServiceImpl_updateListeningPorts( env: JNIEnv, _class: JObject, ports: JIntArray, )412 pub extern "C" fn Java_com_android_virtualization_terminal_DebianServiceImpl_updateListeningPorts(
413 env: JNIEnv,
414 _class: JObject,
415 ports: JIntArray,
416 ) {
417 let length = env.get_array_length(&ports).expect("Failed to get length of port array");
418 let mut buf = vec![0; length as usize];
419 env.get_int_array_region(ports, 0, &mut buf).expect("Failed to get port array");
420
421 let mut update_queue = UPDATE_QUEUE.lock().unwrap();
422 update_queue.clear();
423 for port in buf {
424 update_queue.push_back(port.try_into().expect("Failed to add port into update queue"));
425 }
426 UPDATE_EVT.write(1).expect("failed to write update eventfd");
427 }
428