1 use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts};
2 use futures_util::future::BoxFuture;
3 use http::Request;
4 use http_body::Limited;
5 
6 mod sealed {
7     pub trait Sealed<B> {}
8     impl<B> Sealed<B> for http::Request<B> {}
9 }
10 
/// Extension trait that adds additional methods to [`Request`].
///
/// The trait is sealed: its `Sealed` supertrait is declared in a private module, so it cannot
/// be implemented outside this crate. In this file the only `Sealed` impl is for
/// `http::Request<B>`.
pub trait RequestExt<B>: sealed::Sealed<B> + Sized {
    /// Apply an extractor to this `Request`.
    ///
    /// This is just a convenience for `E::from_request(req, &())`, i.e. the extractor is run
    /// with the unit `()` state.
    ///
    /// Note this consumes the request. Use [`RequestExt::extract_parts`] if you're not extracting
    /// the body and don't want to consume the request.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::FromRequest,
    ///     http::{header::CONTENT_TYPE, Request, StatusCode},
    ///     response::{IntoResponse, Response},
    ///     Form, Json, RequestExt,
    /// };
    ///
    /// struct FormOrJson<T>(T);
    ///
    /// #[async_trait]
    /// impl<S, B, T> FromRequest<S, B> for FormOrJson<T>
    /// where
    ///     Json<T>: FromRequest<(), B>,
    ///     Form<T>: FromRequest<(), B>,
    ///     T: 'static,
    ///     B: Send + 'static,
    ///     S: Send + Sync,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
    ///         let content_type = req
    ///             .headers()
    ///             .get(CONTENT_TYPE)
    ///             .and_then(|value| value.to_str().ok())
    ///             .ok_or_else(|| StatusCode::BAD_REQUEST.into_response())?;
    ///
    ///         if content_type.starts_with("application/json") {
    ///             let Json(payload) = req
    ///                 .extract::<Json<T>, _>()
    ///                 .await
    ///                 .map_err(|err| err.into_response())?;
    ///
    ///             Ok(Self(payload))
    ///         } else if content_type.starts_with("application/x-www-form-urlencoded") {
    ///             let Form(payload) = req
    ///                 .extract::<Form<T>, _>()
    ///                 .await
    ///                 .map_err(|err| err.into_response())?;
    ///
    ///             Ok(Self(payload))
    ///         } else {
    ///             Err(StatusCode::BAD_REQUEST.into_response())
    ///         }
    ///     }
    /// }
    /// ```
    fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
    where
        E: FromRequest<(), B, M> + 'static,
        M: 'static;

    /// Apply an extractor that requires some state to this `Request`.
    ///
    /// This is just a convenience for `E::from_request(req, state)`. The returned future
    /// borrows `state`, so `state` must outlive it.
    ///
    /// Note this consumes the request. Use [`RequestExt::extract_parts_with_state`] if you're not
    /// extracting the body and don't want to consume the request.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::{FromRef, FromRequest},
    ///     http::Request,
    ///     RequestExt,
    /// };
    ///
    /// struct MyExtractor {
    ///     requires_state: RequiresState,
    /// }
    ///
    /// #[async_trait]
    /// impl<S, B> FromRequest<S, B> for MyExtractor
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    ///     B: Send + 'static,
    /// {
    ///     type Rejection = std::convert::Infallible;
    ///
    ///     async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = req.extract_with_state::<RequiresState, _, _>(state).await?;
    ///
    ///         Ok(Self { requires_state })
    ///     }
    /// }
    ///
    /// // some extractor that consumes the request body and requires state
    /// struct RequiresState { /* ... */ }
    ///
    /// #[async_trait]
    /// impl<S, B> FromRequest<S, B> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    ///     B: Send + 'static,
    /// {
    ///     // ...
    ///     # type Rejection = std::convert::Infallible;
    ///     # async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
    ///     #     todo!()
    ///     # }
    /// }
    /// ```
    fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
    where
        E: FromRequest<S, B, M> + 'static,
        S: Send + Sync;

    /// Apply a parts extractor to this `Request`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, &())`, where `parts` is
    /// built from this request's method, URI, version, headers and extensions. Unlike
    /// [`RequestExt::extract`], the request is not consumed. When the extractor completes, the
    /// (possibly modified) parts are written back into the request.
    ///
    /// The headers and extensions are moved out of the request while the extractor runs. If
    /// the returned future is dropped before it completes, they are not restored.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::FromRequest,
    ///     headers::{authorization::Bearer, Authorization},
    ///     http::Request,
    ///     response::{IntoResponse, Response},
    ///     Json, RequestExt, TypedHeader,
    /// };
    ///
    /// struct MyExtractor<T> {
    ///     bearer_token: String,
    ///     payload: T,
    /// }
    ///
    /// #[async_trait]
    /// impl<S, B, T> FromRequest<S, B> for MyExtractor<T>
    /// where
    ///     B: Send + 'static,
    ///     S: Send + Sync,
    ///     Json<T>: FromRequest<(), B>,
    ///     T: 'static,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(mut req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
    ///         let TypedHeader(auth_header) = req
    ///             .extract_parts::<TypedHeader<Authorization<Bearer>>>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let Json(payload) = req
    ///             .extract::<Json<T>, _>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(Self {
    ///             bearer_token: auth_header.token().to_owned(),
    ///             payload,
    ///         })
    ///     }
    /// }
    /// ```
    fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
    where
        E: FromRequestParts<()> + 'static;

    /// Apply a parts extractor that requires some state to this `Request`.
    ///
    /// This is just a convenience for `E::from_request_parts(parts, state)`, where `parts` is
    /// built from this request's method, URI, version, headers and extensions. When the
    /// extractor completes, the (possibly modified) parts are written back into the request.
    ///
    /// The headers and extensions are moved out of the request while the extractor runs. If
    /// the returned future is dropped before it completes, they are not restored.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{
    ///     async_trait,
    ///     extract::{FromRef, FromRequest, FromRequestParts},
    ///     http::{request::Parts, Request},
    ///     response::{IntoResponse, Response},
    ///     Json, RequestExt,
    /// };
    ///
    /// struct MyExtractor<T> {
    ///     requires_state: RequiresState,
    ///     payload: T,
    /// }
    ///
    /// #[async_trait]
    /// impl<S, B, T> FromRequest<S, B> for MyExtractor<T>
    /// where
    ///     String: FromRef<S>,
    ///     Json<T>: FromRequest<(), B>,
    ///     T: 'static,
    ///     S: Send + Sync,
    ///     B: Send + 'static,
    /// {
    ///     type Rejection = Response;
    ///
    ///     async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
    ///         let requires_state = req
    ///             .extract_parts_with_state::<RequiresState, _>(state)
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         let Json(payload) = req
    ///             .extract::<Json<T>, _>()
    ///             .await
    ///             .map_err(|err| err.into_response())?;
    ///
    ///         Ok(Self {
    ///             requires_state,
    ///             payload,
    ///         })
    ///     }
    /// }
    ///
    /// struct RequiresState {}
    ///
    /// #[async_trait]
    /// impl<S> FromRequestParts<S> for RequiresState
    /// where
    ///     String: FromRef<S>,
    ///     S: Send + Sync,
    /// {
    ///     // ...
    ///     # type Rejection = std::convert::Infallible;
    ///     # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
    ///     #     todo!()
    ///     # }
    /// }
    /// ```
    fn extract_parts_with_state<'a, E, S>(
        &'a mut self,
        state: &'a S,
    ) -> BoxFuture<'a, Result<E, E::Rejection>>
    where
        E: FromRequestParts<S> + 'static,
        S: Send + Sync;

    /// Apply the [default body limit](crate::extract::DefaultBodyLimit).
    ///
    /// The body is wrapped in [`Limited`] using the limit stored in the request's
    /// `DefaultBodyLimitKind` extension. If no such extension is present, a default of
    /// 2,097,152 bytes (2 MiB) is used. If the limit is disabled, the request is returned
    /// as-is in `Err`.
    fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>;

    /// Consumes the request, returning the body wrapped in [`Limited`] if a
    /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
    /// default limit is disabled.
    ///
    /// `Ok` holds the limited body and `Err` holds the unwrapped body. The limit is chosen
    /// exactly as in [`RequestExt::with_limited_body`].
    fn into_limited_body(self) -> Result<Limited<B>, B>;
}
270 
271 impl<B> RequestExt<B> for Request<B>
272 where
273     B: Send + 'static,
274 {
extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>> where E: FromRequest<(), B, M> + 'static, M: 'static,275     fn extract<E, M>(self) -> BoxFuture<'static, Result<E, E::Rejection>>
276     where
277         E: FromRequest<(), B, M> + 'static,
278         M: 'static,
279     {
280         self.extract_with_state(&())
281     }
282 
extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequest<S, B, M> + 'static, S: Send + Sync,283     fn extract_with_state<E, S, M>(self, state: &S) -> BoxFuture<'_, Result<E, E::Rejection>>
284     where
285         E: FromRequest<S, B, M> + 'static,
286         S: Send + Sync,
287     {
288         E::from_request(self, state)
289     }
290 
extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static,291     fn extract_parts<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
292     where
293         E: FromRequestParts<()> + 'static,
294     {
295         self.extract_parts_with_state(&())
296     }
297 
extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync,298     fn extract_parts_with_state<'a, E, S>(
299         &'a mut self,
300         state: &'a S,
301     ) -> BoxFuture<'a, Result<E, E::Rejection>>
302     where
303         E: FromRequestParts<S> + 'static,
304         S: Send + Sync,
305     {
306         let mut req = Request::new(());
307         *req.version_mut() = self.version();
308         *req.method_mut() = self.method().clone();
309         *req.uri_mut() = self.uri().clone();
310         *req.headers_mut() = std::mem::take(self.headers_mut());
311         *req.extensions_mut() = std::mem::take(self.extensions_mut());
312         let (mut parts, _) = req.into_parts();
313 
314         Box::pin(async move {
315             let result = E::from_request_parts(&mut parts, state).await;
316 
317             *self.version_mut() = parts.version;
318             *self.method_mut() = parts.method.clone();
319             *self.uri_mut() = parts.uri.clone();
320             *self.headers_mut() = std::mem::take(&mut parts.headers);
321             *self.extensions_mut() = std::mem::take(&mut parts.extensions);
322 
323             result
324         })
325     }
326 
with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>327     fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>> {
328         // update docs in `axum-core/src/extract/default_body_limit.rs` and
329         // `axum/src/docs/extract.md` if this changes
330         const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb
331 
332         match self.extensions().get::<DefaultBodyLimitKind>().copied() {
333             Some(DefaultBodyLimitKind::Disable) => Err(self),
334             Some(DefaultBodyLimitKind::Limit(limit)) => {
335                 Ok(self.map(|b| http_body::Limited::new(b, limit)))
336             }
337             None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))),
338         }
339     }
340 
into_limited_body(self) -> Result<Limited<B>, B>341     fn into_limited_body(self) -> Result<Limited<B>, B> {
342         self.with_limited_body()
343             .map(Request::into_body)
344             .map_err(Request::into_body)
345     }
346 }
347 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use async_trait::async_trait;
    use http::Method;
    use hyper::Body;

    #[tokio::test]
    async fn extract_without_state() {
        let req = Request::new(());

        let method: Method = req.extract().await.unwrap();

        assert_eq!(method, Method::GET);
    }

    #[tokio::test]
    async fn extract_body_without_state() {
        let req = Request::new(Body::from("foobar"));

        let body: String = req.extract().await.unwrap();

        assert_eq!(body, "foobar");
    }

    #[tokio::test]
    async fn extract_with_state() {
        let req = Request::new(());

        let state = "state".to_owned();

        let State(extracted_state): State<String> = req.extract_with_state(&state).await.unwrap();

        assert_eq!(extracted_state, state);
    }

    #[tokio::test]
    async fn extract_parts_without_state() {
        let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap();

        let method: Method = req.extract_parts().await.unwrap();

        assert_eq!(method, Method::GET);
        assert_eq!(req.headers()["x-foo"], "foo");
    }

    #[tokio::test]
    async fn extract_parts_with_state() {
        let mut req = Request::builder().header("x-foo", "foo").body(()).unwrap();

        let state = "state".to_owned();

        let State(extracted_state): State<String> =
            req.extract_parts_with_state(&state).await.unwrap();

        assert_eq!(extracted_state, state);
        assert_eq!(req.headers()["x-foo"], "foo");
    }

    // Every piece of the request head that `extract_parts_with_state` moves out must be
    // written back once the extractor completes, not just the headers.
    #[tokio::test]
    async fn extract_parts_restores_request_head() {
        let mut req = Request::builder()
            .method(Method::POST)
            .uri("/foo?a=1")
            .header("x-foo", "foo")
            .body(())
            .unwrap();
        req.extensions_mut().insert(42_u32);

        let method: Method = req.extract_parts().await.unwrap();

        assert_eq!(method, Method::POST);
        assert_eq!(*req.method(), Method::POST);
        assert_eq!(req.uri().path(), "/foo");
        assert_eq!(req.uri().query(), Some("a=1"));
        assert_eq!(req.headers()["x-foo"], "foo");
        assert_eq!(req.extensions().get::<u32>(), Some(&42));
    }

    // this stuff just needs to compile
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
        body: String,
    }

    #[async_trait]
    impl<S, B> FromRequest<S, B> for WorksForCustomExtractor
    where
        S: Send + Sync,
        B: Send + 'static,
        String: FromRef<S> + FromRequest<(), B>,
    {
        type Rejection = <String as FromRequest<(), B>>::Rejection;

        async fn from_request(mut req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
            let RequiresState(from_state) = req.extract_parts_with_state(state).await.unwrap();
            let method = req.extract_parts().await.unwrap();
            let body = req.extract().await?;

            Ok(Self {
                method,
                from_state,
                body,
            })
        }
    }
}
441