1 use super::{
2 rejection::{FailedToResolveHost, HostRejection},
3 FromRequestParts,
4 };
5 use async_trait::async_trait;
6 use http::{
7 header::{HeaderMap, FORWARDED},
8 request::Parts,
9 };
10
/// Name of the `X-Forwarded-Host` header consulted by the [`Host`] extractor.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the hostname of the request.
///
/// Hostname is resolved through the following, in order:
/// - `Forwarded` header (its `host` parameter)
/// - `X-Forwarded-Host` header
/// - `Host` header
/// - request target / URI
///
/// Note that user agents can set the `Forwarded`, `X-Forwarded-Host` and `Host` headers to
/// arbitrary values, so make sure to validate the resolved host to avoid security issues (for
/// example, only trust the forwarding headers when requests reach you through a proxy you
/// control).
#[derive(Debug, Clone)]
pub struct Host(pub String);
25
26 #[async_trait]
27 impl<S> FromRequestParts<S> for Host
28 where
29 S: Send + Sync,
30 {
31 type Rejection = HostRejection;
32
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>33 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
34 if let Some(host) = parse_forwarded(&parts.headers) {
35 return Ok(Host(host.to_owned()));
36 }
37
38 if let Some(host) = parts
39 .headers
40 .get(X_FORWARDED_HOST_HEADER_KEY)
41 .and_then(|host| host.to_str().ok())
42 {
43 return Ok(Host(host.to_owned()));
44 }
45
46 if let Some(host) = parts
47 .headers
48 .get(http::header::HOST)
49 .and_then(|host| host.to_str().ok())
50 {
51 return Ok(Host(host.to_owned()));
52 }
53
54 if let Some(host) = parts.uri.host() {
55 return Ok(Host(host.to_owned()));
56 }
57
58 Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
59 }
60 }
61
62 #[allow(warnings)]
parse_forwarded(headers: &HeaderMap) -> Option<&str>63 fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
64 // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
65 let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;
66
67 // get the first set of values
68 let first_value = forwarded_values.split(',').nth(0)?;
69
70 // find the value of the `host` field
71 first_value.split(';').find_map(|pair| {
72 let (key, value) = pair.split_once('=')?;
73 key.trim()
74 .eq_ignore_ascii_case("host")
75 .then(|| value.trim().trim_matches('"'))
76 })
77 }
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::header::HeaderName;

    /// Builds a client for a router whose `/` handler echoes the extracted host as the body.
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host_header() {
        let forwarded_header = "host=some-domain:789";
        let x_forwarded_host_header = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(FORWARDED, forwarded_header)
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, "some-domain:789");
    }

    // NOTE(review): the test client may send a `Host` header of its own, in which case this
    // exercises the `Host` header branch rather than the URI fallback — confirm.
    #[crate::test]
    async fn uri_host() {
        let host = test_client().get("/").send().await.text().await;
        assert!(host.contains("127.0.0.1"));
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // is case insensitive
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // ipv6
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // whitespace around keys and values, and surrounding quotes, are ignored
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43; host = \"example.com\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "example.com");

        // multiple values in one header
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // multiple header values
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // no `host` parameter
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43;proto=https")]);
        assert!(parse_forwarded(&headers).is_none());

        // a `host` only in a later element is not used
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43, host=127.0.0.1")]);
        assert!(parse_forwarded(&headers).is_none());

        // no `Forwarded` header at all
        assert!(parse_forwarded(&HeaderMap::new()).is_none());
    }

    /// Builds a `HeaderMap` from `(name, value)` pairs, appending so repeated names are kept.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}
179