1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/bin/chunnel.rs
17
18 //! Guest-side stream socket forwarder
19
20 use std::fmt;
21 use std::result;
22
23 use clap::Parser;
24 use forwarder::forwarder::{ForwarderError, ForwarderSession};
25 use forwarder::stream::{StreamSocket, StreamSocketError};
26 use poll_token_derive::PollToken;
27 use vmm_sys_util::poll::{PollContext, PollToken};
28
29 #[remain::sorted]
30 #[derive(Debug)]
31 enum Error {
32 ConnectSocket(StreamSocketError),
33 Forward(ForwarderError),
34 PollContextAdd(vmm_sys_util::errno::Error),
35 PollContextDelete(vmm_sys_util::errno::Error),
36 PollContextNew(vmm_sys_util::errno::Error),
37 PollWait(vmm_sys_util::errno::Error),
38 }
39
/// Shorthand for results whose error type is this binary's `Error`.
type Result<T> = result::Result<T, Error>;
41
42 impl fmt::Display for Error {
43 #[remain::check]
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result44 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45 use self::Error::*;
46
47 #[remain::sorted]
48 match self {
49 ConnectSocket(e) => write!(f, "failed to connect socket: {}", e),
50 Forward(e) => write!(f, "failed to forward traffic: {}", e),
51 PollContextAdd(e) => write!(f, "failed to add fd to poll context: {}", e),
52 PollContextDelete(e) => write!(f, "failed to delete fd from poll context: {}", e),
53 PollContextNew(e) => write!(f, "failed to create poll context: {}", e),
54 PollWait(e) => write!(f, "failed to wait for poll: {}", e),
55 }
56 }
57 }
58
run_forwarder(local_stream: StreamSocket, remote_stream: StreamSocket) -> Result<()>59 fn run_forwarder(local_stream: StreamSocket, remote_stream: StreamSocket) -> Result<()> {
60 #[derive(PollToken)]
61 enum Token {
62 LocalStreamReadable,
63 RemoteStreamReadable,
64 }
65 let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
66 poll_ctx.add(&local_stream, Token::LocalStreamReadable).map_err(Error::PollContextAdd)?;
67 poll_ctx.add(&remote_stream, Token::RemoteStreamReadable).map_err(Error::PollContextAdd)?;
68
69 let mut forwarder = ForwarderSession::new(local_stream, remote_stream);
70
71 loop {
72 let events = poll_ctx.wait().map_err(Error::PollWait)?;
73
74 for event in events.iter_readable() {
75 match event.token() {
76 Token::LocalStreamReadable => {
77 let shutdown = forwarder.forward_from_local().map_err(Error::Forward)?;
78 if shutdown {
79 poll_ctx
80 .delete(forwarder.local_stream())
81 .map_err(Error::PollContextDelete)?;
82 }
83 }
84 Token::RemoteStreamReadable => {
85 let shutdown = forwarder.forward_from_remote().map_err(Error::Forward)?;
86 if shutdown {
87 poll_ctx
88 .delete(forwarder.remote_stream())
89 .map_err(Error::PollContextDelete)?;
90 }
91 }
92 }
93 }
94 if forwarder.is_shut_down() {
95 return Ok(());
96 }
97 }
98 }
99
100 #[derive(Parser)]
101 /// Flags for running command
102 pub struct Args {
103 /// Local socket address
104 #[arg(long)]
105 #[arg(alias = "local")]
106 local_sockaddr: String,
107
108 /// Remote socket address
109 #[arg(long)]
110 #[arg(alias = "remote")]
111 remote_sockaddr: String,
112 }
113
114 // TODO(b/370897694): Support forwarding for datagram socket
main() -> Result<()>115 fn main() -> Result<()> {
116 let args = Args::parse();
117
118 let local_stream = StreamSocket::connect(&args.local_sockaddr).map_err(Error::ConnectSocket)?;
119 let remote_stream =
120 StreamSocket::connect(&args.remote_sockaddr).map_err(Error::ConnectSocket)?;
121
122 run_forwarder(local_stream, remote_stream)
123 }
124