1 use super::{rejection::*, FromRequestParts};
2 use async_trait::async_trait;
3 use http::{request::Parts, Uri};
4 use serde::de::DeserializeOwned;
5 
/// Extractor that deserializes the request's query string into a value of type `T`.
///
/// The wrapped type must implement [`serde::Deserialize`]; the extractor impl
/// requires `T: DeserializeOwned`.
///
/// # Example
///
/// ```rust,no_run
/// use axum::{
///     extract::Query,
///     routing::get,
///     Router,
/// };
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: usize,
///     per_page: usize,
/// }
///
/// // This will parse query strings like `?page=2&per_page=30` into `Pagination`
/// // structs.
/// async fn list_things(pagination: Query<Pagination>) {
///     let pagination: Pagination = pagination.0;
///
///     // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// A query string that fails to deserialize causes the request to be rejected
/// with a `400 Bad Request` response.
///
/// To tell apart parameters that are empty from parameters that are missing,
/// see the [query-params-with-empty-strings][example] example.
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Clone, Copy, Debug, Default)]
pub struct Query<T>(pub T);
50 
51 #[async_trait]
52 impl<T, S> FromRequestParts<S> for Query<T>
53 where
54     T: DeserializeOwned,
55     S: Send + Sync,
56 {
57     type Rejection = QueryRejection;
58 
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>59     async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
60         Self::try_from_uri(&parts.uri)
61     }
62 }
63 
64 impl<T> Query<T>
65 where
66     T: DeserializeOwned,
67 {
68     /// Attempts to construct a [`Query`] from a reference to a [`Uri`].
69     ///
70     /// # Example
71     /// ```
72     /// use axum::extract::Query;
73     /// use http::Uri;
74     /// use serde::Deserialize;
75     ///
76     /// #[derive(Deserialize)]
77     /// struct ExampleParams {
78     ///     foo: String,
79     ///     bar: u32,
80     /// }
81     ///
82     /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
83     /// let result: Query<ExampleParams> = Query::try_from_uri(&uri).unwrap();
84     /// assert_eq!(result.foo, String::from("hello"));
85     /// assert_eq!(result.bar, 42);
86     /// ```
try_from_uri(value: &Uri) -> Result<Self, QueryRejection>87     pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> {
88         let query = value.query().unwrap_or_default();
89         let params =
90             serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?;
91         Ok(Query(params))
92     }
93 }
94 
95 axum_core::__impl_deref!(Query);
96 
#[cfg(test)]
mod tests {
    use crate::{routing::get, test_helpers::TestClient, Router};

    use super::*;
    use axum_core::extract::FromRequest;
    use http::{Request, StatusCode};
    use serde::Deserialize;
    use std::fmt::Debug;

    /// Builds a request for `uri`, runs the `Query<T>` extractor on it and
    /// asserts the extracted value equals `value`.
    async fn check<T>(uri: impl AsRef<str>, value: T)
    where
        T: DeserializeOwned + PartialEq + Debug,
    {
        let req = Request::builder().uri(uri.as_ref()).body(()).unwrap();
        assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value);
    }

    #[crate::test]
    async fn test_query() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Pagination {
            size: Option<u64>,
            page: Option<u64>,
        }

        check(
            "http://example.com/test",
            Pagination {
                size: None,
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10",
            Pagination {
                size: Some(10),
                page: None,
            },
        )
        .await;

        check(
            "http://example.com/test?size=10&page=20",
            Pagination {
                size: Some(10),
                page: Some(20),
            },
        )
        .await;
    }

    #[crate::test]
    async fn correct_rejection_status_code() {
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct Params {
            n: i32,
        }

        async fn handler(_: Query<Params>) {}

        let app = Router::new().route("/", get(handler));
        let client = TestClient::new(app);

        let res = client.get("/?n=hi").send().await;
        assert_eq!(res.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_try_from_uri() {
        #[derive(Deserialize)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }
        let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap();
        assert_eq!(result.foo, String::from("hello"));
        assert_eq!(result.bar, 42);
    }

    #[test]
    fn test_try_from_uri_with_invalid_query() {
        // Field names must match the query keys (`foo`, `bar`). With `_`-prefixed
        // names serde would fail on a *missing* field, and this test would pass
        // even if an invalid `bar` were accepted. The fields are never read,
        // hence the lint allowance.
        #[derive(Deserialize)]
        #[allow(dead_code)]
        struct TestQueryParams {
            foo: String,
            bar: u32,
        }

        // Positive control: the same struct parses a well-formed query, so the
        // failure below can only come from `bar=invalid`.
        let valid: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap();
        assert!(Query::<TestQueryParams>::try_from_uri(&valid).is_ok());

        let uri: Uri = "http://example.com/path?foo=hello&bar=invalid"
            .parse()
            .unwrap();
        let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri);

        assert!(result.is_err());
    }
}
196