1 use super::{rejection::*, FromRequestParts}; 2 use async_trait::async_trait; 3 use http::{request::Parts, Uri}; 4 use serde::de::DeserializeOwned; 5 6 /// Extractor that deserializes query strings into some type. 7 /// 8 /// `T` is expected to implement [`serde::Deserialize`]. 9 /// 10 /// # Example 11 /// 12 /// ```rust,no_run 13 /// use axum::{ 14 /// extract::Query, 15 /// routing::get, 16 /// Router, 17 /// }; 18 /// use serde::Deserialize; 19 /// 20 /// #[derive(Deserialize)] 21 /// struct Pagination { 22 /// page: usize, 23 /// per_page: usize, 24 /// } 25 /// 26 /// // This will parse query strings like `?page=2&per_page=30` into `Pagination` 27 /// // structs. 28 /// async fn list_things(pagination: Query<Pagination>) { 29 /// let pagination: Pagination = pagination.0; 30 /// 31 /// // ... 32 /// } 33 /// 34 /// let app = Router::new().route("/list_things", get(list_things)); 35 /// # async { 36 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); 37 /// # }; 38 /// ``` 39 /// 40 /// If the query string cannot be parsed it will reject the request with a `400 41 /// Bad Request` response. 42 /// 43 /// For handling values being empty vs missing see the [query-params-with-empty-strings][example] 44 /// example. 45 /// 46 /// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs 47 #[cfg_attr(docsrs, doc(cfg(feature = "query")))] 48 #[derive(Debug, Clone, Copy, Default)] 49 pub struct Query<T>(pub T); 50 51 #[async_trait] 52 impl<T, S> FromRequestParts<S> for Query<T> 53 where 54 T: DeserializeOwned, 55 S: Send + Sync, 56 { 57 type Rejection = QueryRejection; 58 from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>59 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { 60 Self::try_from_uri(&parts.uri) 61 } 62 } 63 64 impl<T> Query<T> 65 where 66 T: DeserializeOwned, 67 { 68 /// Attempts to construct a [`Query`] from a reference to a [`Uri`]. 69 /// 70 /// # Example 71 /// ``` 72 /// use axum::extract::Query; 73 /// use http::Uri; 74 /// use serde::Deserialize; 75 /// 76 /// #[derive(Deserialize)] 77 /// struct ExampleParams { 78 /// foo: String, 79 /// bar: u32, 80 /// } 81 /// 82 /// let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); 83 /// let result: Query<ExampleParams> = Query::try_from_uri(&uri).unwrap(); 84 /// assert_eq!(result.foo, String::from("hello")); 85 /// assert_eq!(result.bar, 42); 86 /// ``` try_from_uri(value: &Uri) -> Result<Self, QueryRejection>87 pub fn try_from_uri(value: &Uri) -> Result<Self, QueryRejection> { 88 let query = value.query().unwrap_or_default(); 89 let params = 90 serde_urlencoded::from_str(query).map_err(FailedToDeserializeQueryString::from_err)?; 91 Ok(Query(params)) 92 } 93 } 94 95 axum_core::__impl_deref!(Query); 96 97 #[cfg(test)] 98 mod tests { 99 use crate::{routing::get, test_helpers::TestClient, Router}; 100 101 use super::*; 102 use axum_core::extract::FromRequest; 103 use http::{Request, StatusCode}; 104 use serde::Deserialize; 105 use std::fmt::Debug; 106 check<T>(uri: impl AsRef<str>, value: T) where T: DeserializeOwned + PartialEq + Debug,107 async fn check<T>(uri: impl AsRef<str>, value: T) 108 where 109 T: DeserializeOwned + PartialEq + Debug, 110 { 111 let req = Request::builder().uri(uri.as_ref()).body(()).unwrap(); 112 assert_eq!(Query::<T>::from_request(req, &()).await.unwrap().0, value); 113 } 114 115 #[crate::test] test_query()116 async fn test_query() { 117 #[derive(Debug, PartialEq, Deserialize)] 118 struct Pagination { 119 size: Option<u64>, 120 page: Option<u64>, 121 } 122 123 check( 124 "http://example.com/test", 125 Pagination { 126 size: None, 127 page: None, 128 }, 129 ) 130 .await; 131 132 check( 133 "http://example.com/test?size=10", 134 Pagination { 135 size: Some(10), 136 page: None, 137 }, 138 ) 139 .await; 140 141 check( 142 "http://example.com/test?size=10&page=20", 143 Pagination { 144 size: Some(10), 145 page: Some(20), 146 }, 147 ) 148 .await; 149 } 150 151 #[crate::test] correct_rejection_status_code()152 async fn correct_rejection_status_code() { 153 #[derive(Deserialize)] 154 #[allow(dead_code)] 155 struct Params { 156 n: i32, 157 } 158 159 async fn handler(_: Query<Params>) {} 160 161 let app = Router::new().route("/", get(handler)); 162 let client = TestClient::new(app); 163 164 let res = client.get("/?n=hi").send().await; 165 assert_eq!(res.status(), StatusCode::BAD_REQUEST); 166 } 167 168 #[test] test_try_from_uri()169 fn test_try_from_uri() { 170 #[derive(Deserialize)] 171 struct TestQueryParams { 172 foo: String, 173 bar: u32, 174 } 175 let uri: Uri = "http://example.com/path?foo=hello&bar=42".parse().unwrap(); 176 let result: Query<TestQueryParams> = Query::try_from_uri(&uri).unwrap(); 177 assert_eq!(result.foo, String::from("hello")); 178 assert_eq!(result.bar, 42); 179 } 180 181 #[test] test_try_from_uri_with_invalid_query()182 fn test_try_from_uri_with_invalid_query() { 183 #[derive(Deserialize)] 184 struct TestQueryParams { 185 _foo: String, 186 _bar: u32, 187 } 188 let uri: Uri = "http://example.com/path?foo=hello&bar=invalid" 189 .parse() 190 .unwrap(); 191 let result: Result<Query<TestQueryParams>, _> = Query::try_from_uri(&uri); 192 193 assert!(result.is_err()); 194 } 195 } 196