1 use super::{
2     rejection::{FailedToResolveHost, HostRejection},
3     FromRequestParts,
4 };
5 use async_trait::async_trait;
6 use http::{
7     header::{HeaderMap, FORWARDED},
8     request::Parts,
9 };
10 
/// Name of the `X-Forwarded-Host` header.
///
/// The [`Host`] extractor consults it after the `Forwarded` header and before the `Host` header.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
12 
/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `Forwarded` header (the `host` parameter of the first element of the first header)
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
///
/// Values taken from headers are returned as sent (apart from quote stripping for
/// `Forwarded`), so they may include a port, e.g. `some-domain:123`. When falling back to the
/// request target, only the URI's host component (`Uri::host`) is used.
///
/// Note that user agents can set the `Forwarded`, `X-Forwarded-Host` and `Host` headers to
/// arbitrary values, so make sure to validate the resolved host to avoid security issues.
#[derive(Debug, Clone)]
pub struct Host(pub String);
25 
26 #[async_trait]
27 impl<S> FromRequestParts<S> for Host
28 where
29     S: Send + Sync,
30 {
31     type Rejection = HostRejection;
32 
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>33     async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34         if let Some(host) = parse_forwarded(&parts.headers) {
35             return Ok(Host(host.to_owned()));
36         }
37 
38         if let Some(host) = parts
39             .headers
40             .get(X_FORWARDED_HOST_HEADER_KEY)
41             .and_then(|host| host.to_str().ok())
42         {
43             return Ok(Host(host.to_owned()));
44         }
45 
46         if let Some(host) = parts
47             .headers
48             .get(http::header::HOST)
49             .and_then(|host| host.to_str().ok())
50         {
51             return Ok(Host(host.to_owned()));
52         }
53 
54         if let Some(host) = parts.uri.host() {
55             return Ok(Host(host.to_owned()));
56         }
57 
58         Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
59     }
60 }
61 
62 #[allow(warnings)]
parse_forwarded(headers: &HeaderMap) -> Option<&str>63 fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
64     // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
65     let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
66 
67     // get the first set of values
68     let first_value = forwarded_values.split(',').nth(0)?;
69 
70     // find the value of the `host` field
71     first_value.split(';').find_map(|pair| {
72         let (key, value) = pair.split_once('=')?;
73         key.trim()
74             .eq_ignore_ascii_case("host")
75             .then(|| value.trim().trim_matches('"'))
76     })
77 }
78 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::header::HeaderName;

    /// Builds a client for a router whose `/` handler echoes the extracted host as the body.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=forwarded-domain:789")
            .header(X_FORWARDED_HOST_HEADER_KEY, "some-domain:456")
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, "forwarded-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let host = test_client().get("/").send().await.text().await;
        assert!(host.contains("127.0.0.1"));
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // is case insensitive
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // ipv6
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // multiple values in one header
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` parameter in the first element
        let headers = header_map(&[(FORWARDED, "proto=http;by=203.0.113.43")]);
        assert!(parse_forwarded(&headers).is_none());

        // no `Forwarded` header at all
        assert!(parse_forwarded(&HeaderMap::new()).is_none());
    }

    /// Builds a `HeaderMap`, appending each pair so repeated names produce multiple values.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}
179