1 use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts}; 2 use futures_util::future::BoxFuture; 3 use http::Request; 4 use http_body::Limited; 5 6 mod sealed { 7 pub trait Sealed<B> {} 8 impl<B> Sealed<B> for http::Request<B> {} 9 } 10 11 /// Extension trait that adds additional methods to [`Request`]. 12 pub trait RequestExt<B>: sealed::Sealed<B> + Sized { 13 /// Apply an extractor to this `Request`. 14 /// 15 /// This is just a convenience for `E::from_request(req, &())`. 16 /// 17 /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting 18 /// the body and don't want to consume the request. 19 /// 20 /// # Example 21 /// 22 /// ``` 23 /// use axum::{ 24 /// async_trait, 25 /// extract::FromRequest, 26 /// http::{header::CONTENT_TYPE, Request, StatusCode}, 27 /// response::{IntoResponse, Response}, 28 /// Form, Json, RequestExt, 29 /// }; 30 /// 31 /// struct FormOrJson<T>(T); 32 /// 33 /// #[async_trait] 34 /// impl<S, B, T> FromRequest<S, B> for FormOrJson<T> 35 /// where 36 /// Json<T>: FromRequest<(), B>, 37 /// Form<T>: FromRequest<(), B>, 38 /// T: 'static, 39 /// B: Send + 'static, 40 /// S: Send + Sync, 41 /// { 42 /// type Rejection = Response; 43 /// 44 /// async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> { 45 /// let content_type = req 46 /// .headers() 47 /// .get(CONTENT_TYPE) 48 /// .and_then(|value| value.to_str().ok()) 49 /// .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?; 50 /// 51 /// if content_type.starts_with("application/json") { 52 /// let Json(payload) = req 53 /// .extract::<Json<T>, _>() 54 /// .await 55 /// .map_err(|err| err.into_response())?; 56 /// 57 /// Ok(Self(payload)) 58 /// } else if content_type.starts_with("application/x-www-form-urlencoded") { 59 /// let Form(payload) = req 60 /// .extract::<Form<T>, _>() 61 /// .await 62 /// .map_err(|err| err.into_response())?; 63 /// 64 /// Ok(Self(payload)) 65 /// } else { 66 /// Err(StatusCode::BAD_REQUEST.into_response()) 67 /// } 68 /// } 69 /// } 70 /// ``` extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>> where E: FromRequest<(), B, M> + 'static, M: 'static71 fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>> 72 where 73 E: FromRequest<(), B, M> + 'static, 74 M: 'static; 75 76 /// Apply an extractor that requires some state to this `Request`. 77 /// 78 /// This is just a convenience for `E::from_request(req, state)`. 79 /// 80 /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not 81 /// extracting the body and don't want to consume the request. 82 /// 83 /// # Example 84 /// 85 /// ``` 86 /// use axum::{ 87 /// async_trait, 88 /// extract::{FromRef, FromRequest}, 89 /// http::Request, 90 /// RequestExt, 91 /// }; 92 /// 93 /// struct MyExtractor { 94 /// requires_state: RequiresState, 95 /// } 96 /// 97 /// #[async_trait] 98 /// impl<S, B> FromRequest<S, B> for MyExtractor 99 /// where 100 /// String: FromRef<S>, 101 /// S: Send + Sync, 102 /// B: Send + 'static, 103 /// { 104 /// type Rejection = std::convert::Infallible; 105 /// 106 /// async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { 107 /// let requires_state = req.extract_with_state::<RequiresState, _, _>(state).await?; 108 /// 109 /// Ok(Self { requires_state }) 110 /// } 111 /// } 112 /// 113 /// // some extractor that consumes the request body and requires state 114 /// struct RequiresState { /* ... */ } 115 /// 116 /// #[async_trait] 117 /// impl<S, B> FromRequest<S, B> for RequiresState 118 /// where 119 /// String: FromRef<S>, 120 /// S: Send + Sync, 121 /// B: Send + 'static, 122 /// { 123 /// // ... 124 /// # type Rejection = std::convert::Infallible; 125 /// # async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> { 126 /// # todo!() 127 /// # } 128 /// } 129 /// ``` extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequest<S, B, M> + 'static, S: Send + Sync130 fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>> 131 where 132 E: FromRequest<S, B, M> + 'static, 133 S: Send + Sync; 134 135 /// Apply a parts extractor to this `Request`. 136 /// 137 /// This is just a convenience for `E::from_request_parts(parts, state)`. 138 /// 139 /// # Example 140 /// 141 /// ``` 142 /// use axum::{ 143 /// async_trait, 144 /// extract::FromRequest, 145 /// headers::{authorization::Bearer, Authorization}, 146 /// http::Request, 147 /// response::{IntoResponse, Response}, 148 /// Json, RequestExt, TypedHeader, 149 /// }; 150 /// 151 /// struct MyExtractor<T> { 152 /// bearer_token: String, 153 /// payload: T, 154 /// } 155 /// 156 /// #[async_trait] 157 /// impl<S, B, T> FromRequest<S, B> for MyExtractor<T> 158 /// where 159 /// B: Send + 'static, 160 /// S: Send + Sync, 161 /// Json<T>: FromRequest<(), B>, 162 /// T: 'static, 163 /// { 164 /// type Rejection = Response; 165 /// 166 /// async fn from_request(mut req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> { 167 /// let TypedHeader(auth_header) = req 168 /// .extract_parts::<TypedHeader<Authorization<Bearer>>>() 169 /// .await 170 /// .map_err(|err| err.into_response())?; 171 /// 172 /// let Json(payload) = req 173 /// .extract::<Json<T>, _>() 174 /// .await 175 /// .map_err(|err| err.into_response())?; 176 /// 177 /// Ok(Self { 178 /// bearer_token: auth_header.token().to_owned(), 179 /// payload, 180 /// }) 181 /// } 182 /// } 183 /// ``` extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static184 fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> 185 where 186 E: FromRequestParts<()> + 'static; 187 188 /// Apply a parts extractor that requires some state to this `Request`. 189 /// 190 /// This is just a convenience for `E::from_request_parts(parts, state)`. 191 /// 192 /// # Example 193 /// 194 /// ``` 195 /// use axum::{ 196 /// async_trait, 197 /// extract::{FromRef, FromRequest, FromRequestParts}, 198 /// http::{request::Parts, Request}, 199 /// response::{IntoResponse, Response}, 200 /// Json, RequestExt, 201 /// }; 202 /// 203 /// struct MyExtractor<T> { 204 /// requires_state: RequiresState, 205 /// payload: T, 206 /// } 207 /// 208 /// #[async_trait] 209 /// impl<S, B, T> FromRequest<S, B> for MyExtractor<T> 210 /// where 211 /// String: FromRef<S>, 212 /// Json<T>: FromRequest<(), B>, 213 /// T: 'static, 214 /// S: Send + Sync, 215 /// B: Send + 'static, 216 /// { 217 /// type Rejection = Response; 218 /// 219 /// async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { 220 /// let requires_state = req 221 /// .extract_parts_with_state::<RequiresState, _>(state) 222 /// .await 223 /// .map_err(|err| err.into_response())?; 224 /// 225 /// let Json(payload) = req 226 /// .extract::<Json<T>, _>() 227 /// .await 228 /// .map_err(|err| err.into_response())?; 229 /// 230 /// Ok(Self { 231 /// requires_state, 232 /// payload, 233 /// }) 234 /// } 235 /// } 236 /// 237 /// struct RequiresState {} 238 /// 239 /// #[async_trait] 240 /// impl<S> FromRequestParts<S> for RequiresState 241 /// where 242 /// String: FromRef<S>, 243 /// S: Send + Sync, 244 /// { 245 /// // ... 246 /// # type Rejection = std::convert::Infallible; 247 /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 248 /// # todo!() 249 /// # } 250 /// } 251 /// ``` extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync252 fn extract_parts_with_state<'a, E, S>( 253 &'a mut self, 254 state: &'a S, 255 ) -> BoxFuture<'a, Result<E, E::Rejection>> 256 where 257 E: FromRequestParts<S> + 'static, 258 S: Send + Sync; 259 260 /// Apply the [default body limit](crate::extract::DefaultBodyLimit). 261 /// 262 /// If it is disabled, return the request as-is in `Err`. with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>263 fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>; 264 265 /// Consumes the request, returning the body wrapped in [`Limited`] if a 266 /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the 267 /// default limit is disabled. into_limited_body(self) -> Result<Limited<B>, B>268 fn into_limited_body(self) -> Result<Limited<B>, B>; 269 } 270 271 impl<B> RequestExt<B> for Request<B> 272 where 273 B: Send + 'static, 274 { extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>> where E: FromRequest<(), B, M> + 'static, M: 'static,275 fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>> 276 where 277 E: FromRequest<(), B, M> + 'static, 278 M: 'static, 279 { 280 self.extract_with_state(&()) 281 } 282 extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequest<S, B, M> + 'static, S: Send + Sync,283 fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>> 284 where 285 E: FromRequest<S, B, M> + 'static, 286 S: Send + Sync, 287 { 288 E::from_request(self, state) 289 } 290 extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static,291 fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> 292 where 293 E: FromRequestParts<()> + 'static, 294 { 295 self.extract_parts_with_state(&()) 296 } 297 extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync,298 fn extract_parts_with_state<'a, E, S>( 299 &'a mut self, 300 state: &'a S, 301 ) -> BoxFuture<'a, Result<E, E::Rejection>> 302 where 303 E: FromRequestParts<S> + 'static, 304 S: Send + Sync, 305 { 306 let mut req = Request::new(()); 307 *req.version_mut() = self.version(); 308 *req.method_mut() = self.method().clone(); 309 *req.uri_mut() = self.uri().clone(); 310 *req.headers_mut() = std::mem::take(self.headers_mut()); 311 *req.extensions_mut() = std::mem::take(self.extensions_mut()); 312 let (mut parts, _) = req.into_parts(); 313 314 Box::pin(async move { 315 let result = E::from_request_parts(&mut parts, state).await; 316 317 *self.version_mut() = parts.version; 318 *self.method_mut() = parts.method.clone(); 319 *self.uri_mut() = parts.uri.clone(); 320 *self.headers_mut() = std::mem::take(&mut parts.headers); 321 *self.extensions_mut() = std::mem::take(&mut parts.extensions); 322 323 result 324 }) 325 } 326 with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>327 fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>> { 328 // update docs in `axum-core/src/extract/default_body_limit.rs` and 329 // `axum/src/docs/extract.md` if this changes 330 const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb 331 332 match self.extensions().get::<DefaultBodyLimitKind>().copied() { 333 Some(DefaultBodyLimitKind::Disable) => Err(self), 334 Some(DefaultBodyLimitKind::Limit(limit)) => { 335 Ok(self.map(|b| http_body::Limited::new(b, limit))) 336 } 337 None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))), 338 } 339 } 340 into_limited_body(self) -> Result<Limited<B>, B>341 fn into_limited_body(self) -> Result<Limited<B>, B> { 342 self.with_limited_body() 343 .map(Request::into_body) 344 .map_err(Request::into_body) 345 } 346 } 347 348 #[cfg(test)] 349 mod tests { 350 use super::*; 351 use crate::{ 352 ext_traits::tests::{RequiresState, State}, 353 extract::FromRef, 354 }; 355 use async_trait::async_trait; 356 use http::Method; 357 use hyper::Body; 358 359 #[tokio::test] extract_without_state()360 async fn extract_without_state() { 361 let req = Request::new(()); 362 363 let method: Method = req.extract().await.unwrap(); 364 365 assert_eq!(method, Method::GET); 366 } 367 368 #[tokio::test] extract_body_without_state()369 async fn extract_body_without_state() { 370 let req = Request::new(Body::from("foobar")); 371 372 let body: String = req.extract().await.unwrap(); 373 374 assert_eq!(body, "foobar"); 375 } 376 377 #[tokio::test] extract_with_state()378 async fn extract_with_state() { 379 let req = Request::new(()); 380 381 let state = "state".to_owned(); 382 383 let State(extracted_state): State<String> = req.extract_with_state(&state).await.unwrap(); 384 385 assert_eq!(extracted_state, state); 386 } 387 388 #[tokio::test] extract_parts_without_state()389 async fn extract_parts_without_state() { 390 let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); 391 392 let method: Method = req.extract_parts().await.unwrap(); 393 394 assert_eq!(method, Method::GET); 395 assert_eq!(req.headers()["x-foo"], "foo"); 396 } 397 398 #[tokio::test] extract_parts_with_state()399 async fn extract_parts_with_state() { 400 let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap(); 401 402 let state = "state".to_owned(); 403 404 let State(extracted_state): State<String> = 405 req.extract_parts_with_state(&state).await.unwrap(); 406 407 assert_eq!(extracted_state, state); 408 assert_eq!(req.headers()["x-foo"], "foo"); 409 } 410 411 // this stuff just needs to compile 412 #[allow(dead_code)] 413 struct WorksForCustomExtractor { 414 method: Method, 415 from_state: String, 416 body: String, 417 } 418 419 #[async_trait] 420 impl<S, B> FromRequest<S, B> for WorksForCustomExtractor 421 where 422 S: Send + Sync, 423 B: Send + 'static, 424 String: FromRef<S> + FromRequest<(), B>, 425 { 426 type Rejection = <String as FromRequest<(), B>>::Rejection; 427 from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection>428 async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> { 429 let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap(); 430 let method = req.extract_parts().await.unwrap(); 431 let body = req.extract().await?; 432 433 Ok(Self { 434 method, 435 from_state, 436 body, 437 }) 438 } 439 } 440 } 441