1 use crate::{
2     extract::FromRequestParts,
3     response::{IntoResponse, Response},
4 };
5 use futures_util::{future::BoxFuture, ready};
6 use http::Request;
7 use pin_project_lite::pin_project;
8 use std::{
9     fmt,
10     future::Future,
11     marker::PhantomData,
12     pin::Pin,
13     task::{Context, Poll},
14 };
15 use tower_layer::Layer;
16 use tower_service::Service;
17 
18 /// Create a middleware from an extractor.
19 ///
20 /// If the extractor succeeds the value will be discarded and the inner service
21 /// will be called. If the extractor fails the rejection will be returned and
22 /// the inner service will _not_ be called.
23 ///
24 /// This can be used to perform validation of requests if the validation doesn't
25 /// produce any useful output, and run the extractor for several handlers
26 /// without repeating it in the function signature.
27 ///
28 /// Note that if the extractor consumes the request body, as `String` or
29 /// [`Bytes`] does, an empty body will be left in its place. Thus wont be
30 /// accessible to subsequent extractors or handlers.
31 ///
32 /// # Example
33 ///
34 /// ```rust
35 /// use axum::{
36 ///     extract::FromRequestParts,
37 ///     middleware::from_extractor,
38 ///     routing::{get, post},
39 ///     Router,
40 ///     http::{header, StatusCode, request::Parts},
41 /// };
42 /// use async_trait::async_trait;
43 ///
44 /// // An extractor that performs authorization.
45 /// struct RequireAuth;
46 ///
47 /// #[async_trait]
48 /// impl<S> FromRequestParts<S> for RequireAuth
49 /// where
50 ///     S: Send + Sync,
51 /// {
52 ///     type Rejection = StatusCode;
53 ///
54 ///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
55 ///         let auth_header = parts
56 ///             .headers
57 ///             .get(header::AUTHORIZATION)
58 ///             .and_then(|value| value.to_str().ok());
59 ///
60 ///         match auth_header {
61 ///             Some(auth_header) if token_is_valid(auth_header) => {
62 ///                 Ok(Self)
63 ///             }
64 ///             _ => Err(StatusCode::UNAUTHORIZED),
65 ///         }
66 ///     }
67 /// }
68 ///
69 /// fn token_is_valid(token: &str) -> bool {
70 ///     // ...
71 ///     # false
72 /// }
73 ///
74 /// async fn handler() {
75 ///     // If we get here the request has been authorized
76 /// }
77 ///
78 /// async fn other_handler() {
79 ///     // If we get here the request has been authorized
80 /// }
81 ///
82 /// let app = Router::new()
83 ///     .route("/", get(handler))
84 ///     .route("/foo", post(other_handler))
85 ///     // The extractor will run before all routes
86 ///     .route_layer(from_extractor::<RequireAuth>());
87 /// # async {
88 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
89 /// # };
90 /// ```
91 ///
92 /// [`Bytes`]: bytes::Bytes
from_extractor<E>() -> FromExtractorLayer<E, ()>93 pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
94     from_extractor_with_state(())
95 }
96 
97 /// Create a middleware from an extractor with the given state.
98 ///
99 /// See [`State`](crate::extract::State) for more details about accessing state.
from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S>100 pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
101     FromExtractorLayer {
102         state,
103         _marker: PhantomData,
104     }
105 }
106 
/// [`Layer`] that wraps services in [`FromExtractor`], which runs an extractor
/// on every request and discards the extracted value.
///
/// See [`from_extractor`] for more details.
///
/// [`Layer`]: tower::Layer
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // Records which extractor to run. Using `fn() -> E` keeps the layer
    // `Send + Sync` independently of `E`, and no `E` value is ever stored.
    _marker: PhantomData<fn() -> E>,
    // State handed to the extractor; cloned into each service this layer builds.
    state: S,
}
118 
119 impl<E, S> Clone for FromExtractorLayer<E, S>
120 where
121     S: Clone,
122 {
clone(&self) -> Self123     fn clone(&self) -> Self {
124         Self {
125             state: self.state.clone(),
126             _marker: PhantomData,
127         }
128     }
129 }
130 
131 impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
132 where
133     S: fmt::Debug,
134 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result135     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136         f.debug_struct("FromExtractorLayer")
137             .field("state", &self.state)
138             .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
139             .finish()
140     }
141 }
142 
143 impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
144 where
145     S: Clone,
146 {
147     type Service = FromExtractor<T, E, S>;
148 
layer(&self, inner: T) -> Self::Service149     fn layer(&self, inner: T) -> Self::Service {
150         FromExtractor {
151             inner,
152             state: self.state.clone(),
153             _extractor: PhantomData,
154         }
155     }
156 }
157 
/// Middleware that runs an extractor on each request, throws away the
/// extracted value, and either calls the inner service or returns the
/// extractor's rejection.
///
/// See [`from_extractor`] for more details.
pub struct FromExtractor<T, E, S> {
    // Which extractor to run; `fn() -> E` keeps the middleware `Send + Sync`
    // regardless of `E`.
    _extractor: PhantomData<fn() -> E>,
    // State passed to the extractor.
    state: S,
    // Wrapped service, called only when extraction succeeds.
    inner: T,
}
166 
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, assert_sync, NotSendSync};

    // The extractor type only appears behind `PhantomData<fn() -> E>`, so the
    // middleware must stay `Send + Sync` even for a `!Send + !Sync` extractor.
    type Middleware = FromExtractor<(), NotSendSync, ()>;
    assert_send::<Middleware>();
    assert_sync::<Middleware>();
}
173 
174 impl<T, E, S> Clone for FromExtractor<T, E, S>
175 where
176     T: Clone,
177     S: Clone,
178 {
clone(&self) -> Self179     fn clone(&self) -> Self {
180         Self {
181             inner: self.inner.clone(),
182             state: self.state.clone(),
183             _extractor: PhantomData,
184         }
185     }
186 }
187 
188 impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
189 where
190     T: fmt::Debug,
191     S: fmt::Debug,
192 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result193     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194         f.debug_struct("FromExtractor")
195             .field("inner", &self.inner)
196             .field("state", &self.state)
197             .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
198             .finish()
199     }
200 }
201 
202 impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
203 where
204     E: FromRequestParts<S> + 'static,
205     B: Send + 'static,
206     T: Service<Request<B>> + Clone,
207     T::Response: IntoResponse,
208     S: Clone + Send + Sync + 'static,
209 {
210     type Response = Response;
211     type Error = T::Error;
212     type Future = ResponseFuture<B, T, E, S>;
213 
214     #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>215     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216         self.inner.poll_ready(cx)
217     }
218 
call(&mut self, req: Request<B>) -> Self::Future219     fn call(&mut self, req: Request<B>) -> Self::Future {
220         let state = self.state.clone();
221         let extract_future = Box::pin(async move {
222             let (mut parts, body) = req.into_parts();
223             let extracted = E::from_request_parts(&mut parts, &state).await;
224             let req = Request::from_parts(parts, body);
225             (req, extracted)
226         });
227 
228         ResponseFuture {
229             state: State::Extracting {
230                 future: extract_future,
231             },
232             svc: Some(self.inner.clone()),
233         }
234     }
235 }
236 
pin_project! {
    /// Response future for [`FromExtractor`].
    ///
    /// If extraction fails, it resolves to the extractor's rejection converted
    /// into a response. Otherwise it resolves to the inner service's response.
    // NOTE(review): no `Debug` impl, presumably because the boxed extraction
    // future (`BoxFuture`) does not implement `Debug`.
    #[allow(missing_debug_implementations)]
    pub struct ResponseFuture<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Current step of the extract-then-call state machine.
        #[pin]
        state: State<B, T, E, S>,
        // Inner service to call once extraction succeeds. It is taken (leaving
        // `None`) at the moment the call is made.
        svc: Option<T>,
    }
}
250 
pin_project! {
    // Two-step state machine driven by `ResponseFuture::poll`.
    #[project = StateProj]
    enum State<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Running the extractor. The future is boxed, so it is not structurally
        // pinned (no `#[pin]`). It yields the reassembled request together with
        // the extraction result.
        Extracting {
            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
        },
        // Extraction succeeded; waiting on the inner service's response future.
        Call { #[pin] future: T::Future },
    }
}
264 
265 impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
266 where
267     E: FromRequestParts<S>,
268     T: Service<Request<B>>,
269     T::Response: IntoResponse,
270 {
271     type Output = Result<Response, T::Error>;
272 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>273     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274         loop {
275             let mut this = self.as_mut().project();
276 
277             let new_state = match this.state.as_mut().project() {
278                 StateProj::Extracting { future } => {
279                     let (req, extracted) = ready!(future.as_mut().poll(cx));
280 
281                     match extracted {
282                         Ok(_) => {
283                             let mut svc = this.svc.take().expect("future polled after completion");
284                             let future = svc.call(req);
285                             State::Call { future }
286                         }
287                         Err(err) => {
288                             let res = err.into_response();
289                             return Poll::Ready(Ok(res));
290                         }
291                     }
292                 }
293                 StateProj::Call { future } => {
294                     return future
295                         .poll(cx)
296                         .map(|result| result.map(IntoResponse::into_response));
297                 }
298             };
299 
300             this.state.set(new_state);
301         }
302     }
303 }
304 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{async_trait, handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    /// A stateful extractor rejects requests without the right
    /// `Authorization` header and lets matching ones through.
    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        struct RequireAuth;

        #[async_trait]
        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(secret) = Secret::from_ref(state);
                if let Some(auth) = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|v| v.to_str().ok())
                {
                    if auth == secret {
                        return Ok(Self);
                    }
                }

                Err(StatusCode::UNAUTHORIZED)
            }
        }

        async fn handler() {}

        let state = Secret("secret");
        let app = Router::new().route(
            "/",
            get(handler.layer(from_extractor_with_state::<RequireAuth, _>(state))),
        );

        let client = TestClient::new(app);

        let res = client.get("/").send().await;
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let res = client
            .get("/")
            .header(header::AUTHORIZATION, "secret")
            .send()
            .await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // just needs to compile: the middleware must compose with body-limiting layers
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        #[async_trait]
        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}
393