1 use crate::extract::FromRequestParts; 2 use futures_util::future::BoxFuture; 3 use http::request::Parts; 4 5 mod sealed { 6 pub trait Sealed {} 7 impl Sealed for http::request::Parts {} 8 } 9 10 /// Extension trait that adds additional methods to [`Parts`]. 11 pub trait RequestPartsExt: sealed::Sealed + Sized { 12 /// Apply an extractor to this `Parts`. 13 /// 14 /// This is just a convenience for `E::from_request_parts(parts, &())`. 15 /// 16 /// # Example 17 /// 18 /// ``` 19 /// use axum::{ 20 /// extract::{Query, TypedHeader, FromRequestParts}, 21 /// response::{Response, IntoResponse}, 22 /// headers::UserAgent, 23 /// http::request::Parts, 24 /// RequestPartsExt, 25 /// async_trait, 26 /// }; 27 /// use std::collections::HashMap; 28 /// 29 /// struct MyExtractor { 30 /// user_agent: String, 31 /// query_params: HashMap<String, String>, 32 /// } 33 /// 34 /// #[async_trait] 35 /// impl<S> FromRequestParts<S> for MyExtractor 36 /// where 37 /// S: Send + Sync, 38 /// { 39 /// type Rejection = Response; 40 /// 41 /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 42 /// let user_agent = parts 43 /// .extract::<TypedHeader<UserAgent>>() 44 /// .await 45 /// .map(|user_agent| user_agent.as_str().to_owned()) 46 /// .map_err(|err| err.into_response())?; 47 /// 48 /// let query_params = parts 49 /// .extract::<Query<HashMap<String, String>>>() 50 /// .await 51 /// .map(|Query(params)| params) 52 /// .map_err(|err| err.into_response())?; 53 /// 54 /// Ok(MyExtractor { user_agent, query_params }) 55 /// } 56 /// } 57 /// ``` extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static58 fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> 59 where 60 E: FromRequestParts<()> + 'static; 61 62 /// Apply an extractor that requires some state to this `Parts`. 63 /// 64 /// This is just a convenience for `E::from_request_parts(parts, state)`. 65 /// 66 /// # Example 67 /// 68 /// ``` 69 /// use axum::{ 70 /// extract::{FromRef, FromRequestParts}, 71 /// response::{Response, IntoResponse}, 72 /// http::request::Parts, 73 /// RequestPartsExt, 74 /// async_trait, 75 /// }; 76 /// 77 /// struct MyExtractor { 78 /// requires_state: RequiresState, 79 /// } 80 /// 81 /// #[async_trait] 82 /// impl<S> FromRequestParts<S> for MyExtractor 83 /// where 84 /// String: FromRef<S>, 85 /// S: Send + Sync, 86 /// { 87 /// type Rejection = std::convert::Infallible; 88 /// 89 /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 90 /// let requires_state = parts 91 /// .extract_with_state::<RequiresState, _>(state) 92 /// .await?; 93 /// 94 /// Ok(MyExtractor { requires_state }) 95 /// } 96 /// } 97 /// 98 /// struct RequiresState { /* ... */ } 99 /// 100 /// // some extractor that requires a `String` in the state 101 /// #[async_trait] 102 /// impl<S> FromRequestParts<S> for RequiresState 103 /// where 104 /// String: FromRef<S>, 105 /// S: Send + Sync, 106 /// { 107 /// // ... 108 /// # type Rejection = std::convert::Infallible; 109 /// # async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 110 /// # unimplemented!() 111 /// # } 112 /// } 113 /// ``` extract_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync114 fn extract_with_state<'a, E, S>( 115 &'a mut self, 116 state: &'a S, 117 ) -> BoxFuture<'a, Result<E, E::Rejection>> 118 where 119 E: FromRequestParts<S> + 'static, 120 S: Send + Sync; 121 } 122 123 impl RequestPartsExt for Parts { extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> where E: FromRequestParts<()> + 'static,124 fn extract<E>(&mut self) -> BoxFuture<'_, Result<E, E::Rejection>> 125 where 126 E: FromRequestParts<()> + 'static, 127 { 128 self.extract_with_state(&()) 129 } 130 extract_with_state<'a, E, S>( &'a mut self, state: &'a S, ) -> BoxFuture<'a, Result<E, E::Rejection>> where E: FromRequestParts<S> + 'static, S: Send + Sync,131 fn extract_with_state<'a, E, S>( 132 &'a mut self, 133 state: &'a S, 134 ) -> BoxFuture<'a, Result<E, E::Rejection>> 135 where 136 E: FromRequestParts<S> + 'static, 137 S: Send + Sync, 138 { 139 E::from_request_parts(self, state) 140 } 141 } 142 143 #[cfg(test)] 144 mod tests { 145 use std::convert::Infallible; 146 147 use super::*; 148 use crate::{ 149 ext_traits::tests::{RequiresState, State}, 150 extract::FromRef, 151 }; 152 use async_trait::async_trait; 153 use http::{Method, Request}; 154 155 #[tokio::test] extract_without_state()156 async fn extract_without_state() { 157 let (mut parts, _) = Request::new(()).into_parts(); 158 159 let method: Method = parts.extract().await.unwrap(); 160 161 assert_eq!(method, Method::GET); 162 } 163 164 #[tokio::test] extract_with_state()165 async fn extract_with_state() { 166 let (mut parts, _) = Request::new(()).into_parts(); 167 168 let state = "state".to_owned(); 169 170 let State(extracted_state): State<String> = parts 171 .extract_with_state::<State<String>, String>(&state) 172 .await 173 .unwrap(); 174 175 assert_eq!(extracted_state, state); 176 } 177 178 // this stuff just needs to compile 179 #[allow(dead_code)] 180 struct WorksForCustomExtractor { 181 method: Method, 182 from_state: String, 183 } 184 185 #[async_trait] 186 impl<S> FromRequestParts<S> for WorksForCustomExtractor 187 where 188 S: Send + Sync, 189 String: FromRef<S>, 190 { 191 type Rejection = Infallible; 192 from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection>193 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 194 let RequiresState(from_state) = parts.extract_with_state(state).await?; 195 let method = parts.extract().await?; 196 197 Ok(Self { method, from_state }) 198 } 199 } 200 } 201