1 //! Connection helper.
2 use std::io::{Read, Write};
3
4 use crate::{
5 client::{client_with_config, uri_mode, IntoClientRequest},
6 error::UrlError,
7 handshake::client::Response,
8 protocol::WebSocketConfig,
9 stream::MaybeTlsStream,
10 ClientHandshake, Error, HandshakeError, Result, WebSocket,
11 };
12
/// A connector used when establishing connections. It selects whether `native-tls` or
/// `rustls` creates the TLS connection; TLS can be disabled entirely with the `Plain`
/// variant.
///
/// The TLS variants only exist when the corresponding Cargo feature is enabled.
#[non_exhaustive]
pub enum Connector {
    /// Plain (non-TLS) connector.
    Plain,
    /// `native-tls` TLS connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// `rustls` TLS connector.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}

// Implemented by hand so that `Connector` is `Debug` without requiring anything of the
// wrapped TLS types; only the variant name is printed, never the TLS configuration.
impl std::fmt::Debug for Connector {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Plain => f.write_str("Plain"),
            #[cfg(feature = "native-tls")]
            Self::NativeTls(_) => f.write_str("NativeTls(..)"),
            #[cfg(feature = "__rustls-tls")]
            Self::Rustls(_) => f.write_str("Rustls(..)"),
        }
    }
}
28
29 mod encryption {
30 #[cfg(feature = "native-tls")]
31 pub mod native_tls {
32 use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
33
34 use std::io::{Read, Write};
35
36 use crate::{
37 error::TlsError,
38 stream::{MaybeTlsStream, Mode},
39 Error, Result,
40 };
41
wrap_stream<S>( socket: S, domain: &str, mode: Mode, tls_connector: Option<TlsConnector>, ) -> Result<MaybeTlsStream<S>> where S: Read + Write,42 pub fn wrap_stream<S>(
43 socket: S,
44 domain: &str,
45 mode: Mode,
46 tls_connector: Option<TlsConnector>,
47 ) -> Result<MaybeTlsStream<S>>
48 where
49 S: Read + Write,
50 {
51 match mode {
52 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
53 Mode::Tls => {
54 let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
55 let connector = try_connector.map_err(TlsError::Native)?;
56 let connected = connector.connect(domain, socket);
57 match connected {
58 Err(e) => match e {
59 TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
60 TlsHandshakeError::WouldBlock(_) => {
61 panic!("Bug: TLS handshake not blocked")
62 }
63 },
64 Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
65 }
66 }
67 }
68 }
69 }
70
71 #[cfg(feature = "__rustls-tls")]
72 pub mod rustls {
73 use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
74 use rustls_pki_types::ServerName;
75
76 use std::{
77 convert::TryFrom,
78 io::{Read, Write},
79 sync::Arc,
80 };
81
82 use crate::{
83 error::TlsError,
84 stream::{MaybeTlsStream, Mode},
85 Result,
86 };
87
wrap_stream<S>( socket: S, domain: &str, mode: Mode, tls_connector: Option<Arc<ClientConfig>>, ) -> Result<MaybeTlsStream<S>> where S: Read + Write,88 pub fn wrap_stream<S>(
89 socket: S,
90 domain: &str,
91 mode: Mode,
92 tls_connector: Option<Arc<ClientConfig>>,
93 ) -> Result<MaybeTlsStream<S>>
94 where
95 S: Read + Write,
96 {
97 match mode {
98 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
99 Mode::Tls => {
100 let config = match tls_connector {
101 Some(config) => config,
102 None => {
103 #[allow(unused_mut)]
104 let mut root_store = RootCertStore::empty();
105
106 #[cfg(feature = "rustls-tls-native-roots")]
107 {
108 let native_certs = rustls_native_certs::load_native_certs()?;
109 let total_number = native_certs.len();
110 let (number_added, number_ignored) =
111 root_store.add_parsable_certificates(native_certs);
112 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
113 }
114 #[cfg(feature = "rustls-tls-webpki-roots")]
115 {
116 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
117 }
118
119 Arc::new(
120 ClientConfig::builder()
121 .with_root_certificates(root_store)
122 .with_no_client_auth(),
123 )
124 }
125 };
126 let domain = ServerName::try_from(domain)
127 .map_err(|_| TlsError::InvalidDnsName)?
128 .to_owned();
129 let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
130 let stream = StreamOwned::new(client, socket);
131
132 Ok(MaybeTlsStream::Rustls(stream))
133 }
134 }
135 }
136 }
137
138 pub mod plain {
139 use std::io::{Read, Write};
140
141 use crate::{
142 error::UrlError,
143 stream::{MaybeTlsStream, Mode},
144 Error, Result,
145 };
146
wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>> where S: Read + Write,147 pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
148 where
149 S: Read + Write,
150 {
151 match mode {
152 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
153 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
154 }
155 }
156 }
157 }
158
/// Error type returned by `client_tls` and `client_tls_with_config`: a `HandshakeError`
/// for a client handshake running over a `MaybeTlsStream<S>`.
159 type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
160
161 /// Creates a WebSocket handshake from a request and a stream,
162 /// upgrading the stream to TLS if required.
client_tls<R, S>( request: R, stream: S, ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> where R: IntoClientRequest, S: Read + Write,163 pub fn client_tls<R, S>(
164 request: R,
165 stream: S,
166 ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
167 where
168 R: IntoClientRequest,
169 S: Read + Write,
170 {
171 client_tls_with_config(request, stream, None, None)
172 }
173
174 /// The same as [`client_tls()`] but one can specify a websocket configuration,
175 /// and an optional connector. If no connector is specified, a default one will
176 /// be created.
177 ///
178 /// Please refer to [`client_tls()`] for more details.
client_tls_with_config<R, S>( request: R, stream: S, config: Option<WebSocketConfig>, connector: Option<Connector>, ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>> where R: IntoClientRequest, S: Read + Write,179 pub fn client_tls_with_config<R, S>(
180 request: R,
181 stream: S,
182 config: Option<WebSocketConfig>,
183 connector: Option<Connector>,
184 ) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
185 where
186 R: IntoClientRequest,
187 S: Read + Write,
188 {
189 let request = request.into_client_request()?;
190
191 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
192 let domain = match request.uri().host() {
193 Some(d) => Ok(d.to_string()),
194 None => Err(Error::Url(UrlError::NoHostName)),
195 }?;
196
197 let mode = uri_mode(request.uri())?;
198
199 let stream = match connector {
200 Some(conn) => match conn {
201 #[cfg(feature = "native-tls")]
202 Connector::NativeTls(conn) => {
203 self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
204 }
205 #[cfg(feature = "__rustls-tls")]
206 Connector::Rustls(conn) => {
207 self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
208 }
209 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
210 },
211 None => {
212 #[cfg(feature = "native-tls")]
213 {
214 self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
215 }
216 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
217 {
218 self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
219 }
220 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
221 {
222 self::encryption::plain::wrap_stream(stream, mode)
223 }
224 }
225 }?;
226
227 client_with_config(request, stream, config)
228 }
229