1 //! Connection helper.
2 use std::io::{Read, Write};
3 
4 use crate::{
5     client::{client_with_config, uri_mode, IntoClientRequest},
6     error::UrlError,
7     handshake::client::Response,
8     protocol::WebSocketConfig,
9     stream::MaybeTlsStream,
10     ClientHandshake, Error, HandshakeError, Result, WebSocket,
11 };
12 
/// A connector used when establishing connections. It selects whether `native-tls` or
/// `rustls` creates the TLS connection, or disables TLS entirely via the
/// [`Plain`](Connector::Plain) variant.
///
/// Which TLS variants exist depends on the enabled crate features (`native-tls`,
/// `__rustls-tls`).
#[non_exhaustive]
// NOTE(review): `Debug` is intentionally not implemented here; the reason is not visible in
// this file (possibly a wrapped backend type) — confirm before removing the allow.
#[allow(missing_debug_implementations)]
pub enum Connector {
    /// Plain (non-TLS) connector. If the request URI requires TLS, the connection attempt
    /// fails with `UrlError::TlsFeatureNotEnabled`.
    Plain,
    /// `native-tls` TLS connector, used as-is for the handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// `rustls` client configuration, shared via `Arc` and used as-is to create the session.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
28 
29 mod encryption {
30     #[cfg(feature = "native-tls")]
31     pub mod native_tls {
32         use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
33 
34         use std::io::{Read, Write};
35 
36         use crate::{
37             error::TlsError,
38             stream::{MaybeTlsStream, Mode},
39             Error, Result,
40         };
41 
wrap_stream<S>( socket: S, domain: &str, mode: Mode, tls_connector: Option<TlsConnector>, ) -> Result<MaybeTlsStream<S>> where S: Read + Write,42         pub fn wrap_stream<S>(
43             socket: S,
44             domain: &str,
45             mode: Mode,
46             tls_connector: Option<TlsConnector>,
47         ) -> Result<MaybeTlsStream<S>>
48         where
49             S: Read + Write,
50         {
51             match mode {
52                 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
53                 Mode::Tls => {
54                     let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
55                     let connector = try_connector.map_err(TlsError::Native)?;
56                     let connected = connector.connect(domain, socket);
57                     match connected {
58                         Err(e) => match e {
59                             TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
60                             TlsHandshakeError::WouldBlock(_) => {
61                                 panic!("Bug: TLS handshake not blocked")
62                             }
63                         },
64                         Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
65                     }
66                 }
67             }
68         }
69     }
70 
71     #[cfg(feature = "__rustls-tls")]
72     pub mod rustls {
73         use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
74         use rustls_pki_types::ServerName;
75 
76         use std::{
77             convert::TryFrom,
78             io::{Read, Write},
79             sync::Arc,
80         };
81 
82         use crate::{
83             error::TlsError,
84             stream::{MaybeTlsStream, Mode},
85             Result,
86         };
87 
wrap_stream<S>( socket: S, domain: &str, mode: Mode, tls_connector: Option<Arc<ClientConfig>>, ) -> Result<MaybeTlsStream<S>> where S: Read + Write,88         pub fn wrap_stream<S>(
89             socket: S,
90             domain: &str,
91             mode: Mode,
92             tls_connector: Option<Arc<ClientConfig>>,
93         ) -> Result<MaybeTlsStream<S>>
94         where
95             S: Read + Write,
96         {
97             match mode {
98                 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
99                 Mode::Tls => {
100                     let config = match tls_connector {
101                         Some(config) => config,
102                         None => {
103                             #[allow(unused_mut)]
104                             let mut root_store = RootCertStore::empty();
105 
106                             #[cfg(feature = "rustls-tls-native-roots")]
107                             {
108                                 let native_certs = rustls_native_certs::load_native_certs()?;
109                                 let total_number = native_certs.len();
110                                 let (number_added, number_ignored) =
111                                     root_store.add_parsable_certificates(native_certs);
112                                 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
113                             }
114                             #[cfg(feature = "rustls-tls-webpki-roots")]
115                             {
116                                 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
117                             }
118 
119                             Arc::new(
120                                 ClientConfig::builder()
121                                     .with_root_certificates(root_store)
122                                     .with_no_client_auth(),
123                             )
124                         }
125                     };
126                     let domain = ServerName::try_from(domain)
127                         .map_err(|_| TlsError::InvalidDnsName)?
128                         .to_owned();
129                     let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
130                     let stream = StreamOwned::new(client, socket);
131 
132                     Ok(MaybeTlsStream::Rustls(stream))
133                 }
134             }
135         }
136     }
137 
138     pub mod plain {
139         use std::io::{Read, Write};
140 
141         use crate::{
142             error::UrlError,
143             stream::{MaybeTlsStream, Mode},
144             Error, Result,
145         };
146 
wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>> where S: Read + Write,147         pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
148         where
149             S: Read + Write,
150         {
151             match mode {
152                 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
153                 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
154             }
155         }
156     }
157 }
158 
/// Shorthand for the error type returned by `client_tls` and `client_tls_with_config`: a
/// client handshake error over a stream that may or may not be TLS-wrapped.
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
160 
161 /// Creates a WebSocket handshake from a request and a stream,
162 /// upgrading the stream to TLS if required.
client_tls<R, S>( request: R, stream: S, ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> where R: IntoClientRequest, S: Read + Write,163 pub fn client_tls<R, S>(
164     request: R,
165     stream: S,
166 ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
167 where
168     R: IntoClientRequest,
169     S: Read + Write,
170 {
171     client_tls_with_config(request, stream, None, None)
172 }
173 
174 /// The same as [`client_tls()`] but one can specify a websocket configuration,
175 /// and an optional connector. If no connector is specified, a default one will
176 /// be created.
177 ///
178 /// Please refer to [`client_tls()`] for more details.
client_tls_with_config<R, S>( request: R, stream: S, config: Option<WebSocketConfig>, connector: Option<Connector>, ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> where R: IntoClientRequest, S: Read + Write,179 pub fn client_tls_with_config<R, S>(
180     request: R,
181     stream: S,
182     config: Option<WebSocketConfig>,
183     connector: Option<Connector>,
184 ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
185 where
186     R: IntoClientRequest,
187     S: Read + Write,
188 {
189     let request = request.into_client_request()?;
190 
191     #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
192     let domain = match request.uri().host() {
193         Some(d) => Ok(d.to_string()),
194         None => Err(Error::Url(UrlError::NoHostName)),
195     }?;
196 
197     let mode = uri_mode(request.uri())?;
198 
199     let stream = match connector {
200         Some(conn) => match conn {
201             #[cfg(feature = "native-tls")]
202             Connector::NativeTls(conn) => {
203                 self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
204             }
205             #[cfg(feature = "__rustls-tls")]
206             Connector::Rustls(conn) => {
207                 self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
208             }
209             Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
210         },
211         None => {
212             #[cfg(feature = "native-tls")]
213             {
214                 self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
215             }
216             #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
217             {
218                 self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
219             }
220             #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
221             {
222                 self::encryption::plain::wrap_stream(stream, mode)
223             }
224         }
225     }?;
226 
227     client_with_config(request, stream, config)
228 }
229