1 use crate::extract::FromRequestParts;
2 use futures_util::future::BoxFuture;
3 use http::request::Parts;
4 
5 mod sealed {
6     pub trait Sealed {}
7     impl Sealed for http::request::Parts {}
8 }
9 
10 /// Extension trait that adds additional methods to [`Parts`].
11 pub trait RequestPartsExt: sealed::Sealed + Sized {
12     /// Apply an extractor to this `Parts`.
13     ///
14     /// This is just a convenience for `E::from_request_parts(parts, &())`.
15     ///
16     /// # Example
17     ///
18     /// ```
19     /// use axum::{
20     ///     extract::{Query, TypedHeader, FromRequestParts},
21     ///     response::{Response, IntoResponse},
22     ///     headers::UserAgent,
23     ///     http::request::Parts,
24     ///     RequestPartsExt,
25     ///     async_trait,
26     /// };
27     /// use std::collections::HashMap;
28     ///
29     /// struct MyExtractor {
30     ///     user_agent: String,
31     ///     query_params: HashMap<String, String>,
32     /// }
33     ///
34     /// #[async_trait]
35     /// impl<S> FromRequestParts<S> for MyExtractor
36     /// where
37     ///     S: Send + Sync,
38     /// {
39     ///     type Rejection = Response;
40     ///
41     ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
42     ///         let user_agent = parts
43     ///             .extract::<TypedHeader<UserAgent>>()
44     ///             .await
45     ///             .map(|user_agent| user_agent.as_str().to_owned())
46     ///             .map_err(|err| err.into_response())?;
47     ///
48     ///         let query_params = parts
49     ///             .extract::<Query<HashMap<String, String>>>()
50     ///             .await
51     ///             .map(|Query(params)| params)
52     ///             .map_err(|err| err.into_response())?;
53     ///
54     ///         Ok(MyExtractor { user_agent, query_params })
55     ///     }
56     /// }
57     /// ```
extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static58     fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
59     where
60         E: FromRequestParts<()> + 'static;
61 
62     /// Apply an extractor that requires some state to this `Parts`.
63     ///
64     /// This is just a convenience for `E::from_request_parts(parts, state)`.
65     ///
66     /// # Example
67     ///
68     /// ```
69     /// use axum::{
70     ///     extract::{FromRef, FromRequestParts},
71     ///     response::{Response, IntoResponse},
72     ///     http::request::Parts,
73     ///     RequestPartsExt,
74     ///     async_trait,
75     /// };
76     ///
77     /// struct MyExtractor {
78     ///     requires_state: RequiresState,
79     /// }
80     ///
81     /// #[async_trait]
82     /// impl<S> FromRequestParts<S> for MyExtractor
83     /// where
84     ///     String: FromRef<S>,
85     ///     S: Send + Sync,
86     /// {
87     ///     type Rejection = std::convert::Infallible;
88     ///
89     ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
90     ///         let requires_state = parts
91     ///             .extract_with_state::<RequiresState, _>(state)
92     ///             .await?;
93     ///
94     ///         Ok(MyExtractor { requires_state })
95     ///     }
96     /// }
97     ///
98     /// struct RequiresState { /* ... */ }
99     ///
100     /// // some extractor that requires a `String` in the state
101     /// #[async_trait]
102     /// impl<S> FromRequestParts<S> for RequiresState
103     /// where
104     ///     String: FromRef<S>,
105     ///     S: Send + Sync,
106     /// {
107     ///     // ...
108     ///     # type Rejection = std::convert::Infallible;
109     ///     # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
110     ///     #     unimplemented!()
111     ///     # }
112     /// }
113     /// ```
extract_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync114     fn extract_with_state<'a, E, S>(
115         &'a mut self,
116         state: &'a S,
117     ) -> BoxFuture<'a, Result<E, E::Rejection>>
118     where
119         E: FromRequestParts<S> + 'static,
120         S: Send + Sync;
121 }
122 
123 impl RequestPartsExt for Parts {
extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static,124     fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>>
125     where
126         E: FromRequestParts<()> + 'static,
127     {
128         self.extract_with_state(&())
129     }
130 
extract_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync,131     fn extract_with_state<'a, E, S>(
132         &'a mut self,
133         state: &'a S,
134     ) -> BoxFuture<'a, Result<E, E::Rejection>>
135     where
136         E: FromRequestParts<S> + 'static,
137         S: Send + Sync,
138     {
139         E::from_request_parts(self, state)
140     }
141 }
142 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ext_traits::tests::{RequiresState, State},
        extract::FromRef,
    };
    use async_trait::async_trait;
    use http::{Method, Request};
    use std::convert::Infallible;

    /// `extract` on the parts of `Request::new(())` yields `Method::GET`.
    #[tokio::test]
    async fn extract_without_state() {
        let (mut parts, _body) = Request::new(()).into_parts();

        let method = parts.extract::<Method>().await.unwrap();

        assert_eq!(method, Method::GET);
    }

    /// `extract_with_state` hands the supplied state to the extractor.
    #[tokio::test]
    async fn extract_with_state() {
        let (mut parts, _body) = Request::new(()).into_parts();

        let state = String::from("state");

        let State(extracted_state) = parts
            .extract_with_state::<State<String>, String>(&state)
            .await
            .unwrap();

        assert_eq!(extracted_state, state);
    }

    // Compile-only check: nothing below is executed, it only has to type-check,
    // hence the `dead_code` allowance on the never-read fields.
    #[allow(dead_code)]
    struct WorksForCustomExtractor {
        method: Method,
        from_state: String,
    }

    #[async_trait]
    impl<S> FromRequestParts<S> for WorksForCustomExtractor
    where
        String: FromRef<S>,
        S: Send + Sync,
    {
        type Rejection = Infallible;

        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            // Both calls deliberately rely on type inference for the extractor type.
            let RequiresState(from_state) = parts.extract_with_state(state).await?;
            let method = parts.extract().await?;

            Ok(WorksForCustomExtractor { from_state, method })
        }
    }
}
201