1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/bin/chunnel.rs
17 
18 //! Guest-side stream socket forwarder
19 
20 use std::fmt;
21 use std::result;
22 
23 use clap::Parser;
24 use forwarder::forwarder::{ForwarderError, ForwarderSession};
25 use forwarder::stream::{StreamSocket, StreamSocketError};
26 use poll_token_derive::PollToken;
27 use vmm_sys_util::poll::{PollContext, PollToken};
28 
29 #[remain::sorted]
30 #[derive(Debug)]
31 enum Error {
32     ConnectSocket(StreamSocketError),
33     Forward(ForwarderError),
34     PollContextAdd(vmm_sys_util::errno::Error),
35     PollContextDelete(vmm_sys_util::errno::Error),
36     PollContextNew(vmm_sys_util::errno::Error),
37     PollWait(vmm_sys_util::errno::Error),
38 }
39 
40 type Result<T> = result::Result<T, Error>;
41 
42 impl fmt::Display for Error {
43     #[remain::check]
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result44     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45         use self::Error::*;
46 
47         #[remain::sorted]
48         match self {
49             ConnectSocket(e) => write!(f, "failed to connect socket: {}", e),
50             Forward(e) => write!(f, "failed to forward traffic: {}", e),
51             PollContextAdd(e) => write!(f, "failed to add fd to poll context: {}", e),
52             PollContextDelete(e) => write!(f, "failed to delete fd from poll context: {}", e),
53             PollContextNew(e) => write!(f, "failed to create poll context: {}", e),
54             PollWait(e) => write!(f, "failed to wait for poll: {}", e),
55         }
56     }
57 }
58 
run_forwarder(local_stream: StreamSocket, remote_stream: StreamSocket) -> Result<()>59 fn run_forwarder(local_stream: StreamSocket, remote_stream: StreamSocket) -> Result<()> {
60     #[derive(PollToken)]
61     enum Token {
62         LocalStreamReadable,
63         RemoteStreamReadable,
64     }
65     let poll_ctx: PollContext<Token> = PollContext::new().map_err(Error::PollContextNew)?;
66     poll_ctx.add(&local_stream, Token::LocalStreamReadable).map_err(Error::PollContextAdd)?;
67     poll_ctx.add(&remote_stream, Token::RemoteStreamReadable).map_err(Error::PollContextAdd)?;
68 
69     let mut forwarder = ForwarderSession::new(local_stream, remote_stream);
70 
71     loop {
72         let events = poll_ctx.wait().map_err(Error::PollWait)?;
73 
74         for event in events.iter_readable() {
75             match event.token() {
76                 Token::LocalStreamReadable => {
77                     let shutdown = forwarder.forward_from_local().map_err(Error::Forward)?;
78                     if shutdown {
79                         poll_ctx
80                             .delete(forwarder.local_stream())
81                             .map_err(Error::PollContextDelete)?;
82                     }
83                 }
84                 Token::RemoteStreamReadable => {
85                     let shutdown = forwarder.forward_from_remote().map_err(Error::Forward)?;
86                     if shutdown {
87                         poll_ctx
88                             .delete(forwarder.remote_stream())
89                             .map_err(Error::PollContextDelete)?;
90                     }
91                 }
92             }
93         }
94         if forwarder.is_shut_down() {
95             return Ok(());
96         }
97     }
98 }
99 
100 #[derive(Parser)]
101 /// Flags for running command
102 pub struct Args {
103     /// Local socket address
104     #[arg(long)]
105     #[arg(alias = "local")]
106     local_sockaddr: String,
107 
108     /// Remote socket address
109     #[arg(long)]
110     #[arg(alias = "remote")]
111     remote_sockaddr: String,
112 }
113 
114 // TODO(b/370897694): Support forwarding for datagram socket
main() -> Result<()>115 fn main() -> Result<()> {
116     let args = Args::parse();
117 
118     let local_stream = StreamSocket::connect(&args.local_sockaddr).map_err(Error::ConnectSocket)?;
119     let remote_stream =
120         StreamSocket::connect(&args.remote_sockaddr).map_err(Error::ConnectSocket)?;
121 
122     run_forwarder(local_stream, remote_stream)
123 }
124