1 use crate::response::{IntoResponse, Response};
2 use axum_core::extract::{FromRequest, FromRequestParts};
3 use futures_util::future::BoxFuture;
4 use http::Request;
5 use std::{
6     any::type_name,
7     convert::Infallible,
8     fmt,
9     future::Future,
10     marker::PhantomData,
11     pin::Pin,
12     task::{Context, Poll},
13 };
14 use tower_layer::Layer;
15 use tower_service::Service;
16 
17 /// Create a middleware from an async function that transforms a request.
18 ///
19 /// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific
20 /// extractors.
21 ///
22 /// # Example
23 ///
24 /// ```
25 /// use axum::{
26 ///     Router,
27 ///     routing::get,
28 ///     middleware::map_request,
29 ///     http::Request,
30 /// };
31 ///
32 /// async fn set_header<B>(mut request: Request<B>) -> Request<B> {
33 ///     request.headers_mut().insert("x-foo", "foo".parse().unwrap());
34 ///     request
35 /// }
36 ///
37 /// async fn handler<B>(request: Request<B>) {
38 ///     // `request` will have an `x-foo` header
39 /// }
40 ///
41 /// let app = Router::new()
42 ///     .route("/", get(handler))
43 ///     .layer(map_request(set_header));
44 /// # let _: Router = app;
45 /// ```
46 ///
47 /// # Rejecting the request
48 ///
49 /// The function given to `map_request` is allowed to also return a `Result` which can be used to
50 /// reject the request and return a response immediately, without calling the remaining
51 /// middleware.
52 ///
53 /// Specifically the valid return types are:
54 ///
55 /// - `Request<B>`
56 /// - `Result<Request<B>, E> where E:  IntoResponse`
57 ///
58 /// ```
59 /// use axum::{
60 ///     Router,
61 ///     http::{Request, StatusCode},
62 ///     routing::get,
63 ///     middleware::map_request,
64 /// };
65 ///
66 /// async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
67 ///     let auth_header = request.headers()
68 ///         .get(http::header::AUTHORIZATION)
69 ///         .and_then(|header| header.to_str().ok());
70 ///
71 ///     match auth_header {
72 ///         Some(auth_header) if token_is_valid(auth_header) => Ok(request),
73 ///         _ => Err(StatusCode::UNAUTHORIZED),
74 ///     }
75 /// }
76 ///
77 /// fn token_is_valid(token: &str) -> bool {
78 ///     // ...
79 ///     # false
80 /// }
81 ///
82 /// let app = Router::new()
83 ///     .route("/", get(|| async { /* ... */ }))
84 ///     .route_layer(map_request(auth));
85 /// # let app: Router = app;
86 /// ```
87 ///
88 /// # Running extractors
89 ///
90 /// ```
91 /// use axum::{
92 ///     Router,
93 ///     routing::get,
94 ///     middleware::map_request,
95 ///     extract::Path,
96 ///     http::Request,
97 /// };
98 /// use std::collections::HashMap;
99 ///
100 /// async fn log_path_params<B>(
101 ///     Path(path_params): Path<HashMap<String, String>>,
102 ///     request: Request<B>,
103 /// ) -> Request<B> {
104 ///     tracing::debug!(?path_params);
105 ///     request
106 /// }
107 ///
108 /// let app = Router::new()
109 ///     .route("/", get(|| async { /* ... */ }))
110 ///     .layer(map_request(log_path_params));
111 /// # let _: Router = app;
112 /// ```
113 ///
114 /// Note that to access state you must use either [`map_request_with_state`].
map_request<F, T>(f: F) -> MapRequestLayer<F, (), T>115 pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
116     map_request_with_state((), f)
117 }
118 
119 /// Create a middleware from an async function that transforms a request, with the given state.
120 ///
121 /// See [`State`](crate::extract::State) for more details about accessing state.
122 ///
123 /// # Example
124 ///
125 /// ```rust
126 /// use axum::{
127 ///     Router,
128 ///     http::{Request, StatusCode},
129 ///     routing::get,
130 ///     response::IntoResponse,
131 ///     middleware::map_request_with_state,
132 ///     extract::State,
133 /// };
134 ///
135 /// #[derive(Clone)]
136 /// struct AppState { /* ... */ }
137 ///
138 /// async fn my_middleware<B>(
139 ///     State(state): State<AppState>,
140 ///     // you can add more extractors here but the last
141 ///     // extractor must implement `FromRequest` which
142 ///     // `Request` does
143 ///     request: Request<B>,
144 /// ) -> Request<B> {
145 ///     // do something with `state` and `request`...
146 ///     request
147 /// }
148 ///
149 /// let state = AppState { /* ... */ };
150 ///
151 /// let app = Router::new()
152 ///     .route("/", get(|| async { /* ... */ }))
153 ///     .route_layer(map_request_with_state(state.clone(), my_middleware))
154 ///     .with_state(state);
155 /// # let _: axum::Router = app;
156 /// ```
map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T>157 pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
158     MapRequestLayer {
159         f,
160         state,
161         _extractor: PhantomData,
162     }
163 }
164 
/// A [`tower::Layer`] from an async function that transforms a request.
///
/// Created with [`map_request`] or [`map_request_with_state`]. See those functions for more
/// details.
#[must_use]
pub struct MapRequestLayer<F, S, T> {
    // The async function that is run on each request.
    f: F,
    // State handed to the extractors that `f` takes as arguments.
    state: S,
    // Records the extractor tuple `T` in the type without storing a value of it.
    _extractor: PhantomData<fn() -> T>,
}
174 
175 impl<F, S, T> Clone for MapRequestLayer<F, S, T>
176 where
177     F: Clone,
178     S: Clone,
179 {
clone(&self) -> Self180     fn clone(&self) -> Self {
181         Self {
182             f: self.f.clone(),
183             state: self.state.clone(),
184             _extractor: self._extractor,
185         }
186     }
187 }
188 
189 impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
190 where
191     F: Clone,
192     S: Clone,
193 {
194     type Service = MapRequest<F, S, I, T>;
195 
layer(&self, inner: I) -> Self::Service196     fn layer(&self, inner: I) -> Self::Service {
197         MapRequest {
198             f: self.f.clone(),
199             state: self.state.clone(),
200             inner,
201             _extractor: PhantomData,
202         }
203     }
204 }
205 
206 impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
207 where
208     S: fmt::Debug,
209 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result210     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211         f.debug_struct("MapRequestLayer")
212             // Write out the type name, without quoting it as `&type_name::<F>()` would
213             .field("f", &format_args!("{}", type_name::<F>()))
214             .field("state", &self.state)
215             .finish()
216     }
217 }
218 
/// A middleware created from an async function that transforms a request.
///
/// Created with [`map_request`] or [`map_request_with_state`]. See those functions for more
/// details.
pub struct MapRequest<F, S, I, T> {
    // The async function that is run on each request.
    f: F,
    // The wrapped service that receives the request returned by `f`.
    inner: I,
    // State handed to the extractors that `f` takes as arguments.
    state: S,
    // Records the extractor tuple `T` in the type without storing a value of it.
    _extractor: PhantomData<fn() -> T>,
}
228 
229 impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
230 where
231     F: Clone,
232     I: Clone,
233     S: Clone,
234 {
clone(&self) -> Self235     fn clone(&self) -> Self {
236         Self {
237             f: self.f.clone(),
238             inner: self.inner.clone(),
239             state: self.state.clone(),
240             _extractor: self._extractor,
241         }
242     }
243 }
244 
// Generates the `Service` implementation for `MapRequest` for one extractor-tuple shape:
// zero or more leading `FromRequestParts` extractors (`$ty`) followed by one final
// `FromRequest` extractor (`$last`).
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the extractor type identifiers double as variable names below.
        // `unused_mut`: if there are no leading extractors, `parts` is never mutably borrowed.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S, B> + Send,
            Fut: Future + Send + 'static,
            Fut::Output: IntoMapRequestResult<B> + Send + 'static,
            I: Service<Request<B>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            B: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is entirely that of the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Move the instance that `poll_ready` was called on into the future and leave
                // a fresh clone behind in `self`.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                // The future must be `'static`, so it gets its own copies of `f` and the state.
                let mut f = self.f.clone();
                let state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run the leading extractors in order against the request parts. The first
                    // rejection becomes the response; neither `f` nor the inner service runs.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // The final extractor consumes the reassembled request, body included.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // `Ok` carries the request to forward to the inner service; `Err` carries
                    // the response to return without calling it.
                    match f($($ty,)* $last).await.into_map_request_result() {
                        Ok(req) => {
                            ready_inner.call(req).await.into_response()
                        }
                        Err(res) => {
                            res
                        }
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// `all_the_tuples!` is defined elsewhere in the crate; presumably it invokes `impl_service!`
// once per supported number of extractors — confirm against its definition.
all_the_tuples!(impl_service);
317 
318 impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
319 where
320     S: fmt::Debug,
321     I: fmt::Debug,
322 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result323     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324         f.debug_struct("MapRequest")
325             .field("f", &format_args!("{}", type_name::<F>()))
326             .field("inner", &self.inner)
327             .field("state", &self.state)
328             .finish()
329     }
330 }
331 
/// Response future for [`MapRequest`].
pub struct ResponseFuture {
    // The boxed future built in `MapRequest`'s `call`; it resolves to the final response.
    inner: BoxFuture<'static, Response>,
}
336 
337 impl Future for ResponseFuture {
338     type Output = Result<Response, Infallible>;
339 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>340     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
341         self.inner.as_mut().poll(cx).map(Ok)
342     }
343 }
344 
345 impl fmt::Debug for ResponseFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result346     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347         f.debug_struct("ResponseFuture").finish()
348     }
349 }
350 
// Sealing helper: `Sealed` is a supertrait of `IntoMapRequestResult`, and because this module
// is private, code outside this crate cannot name `Sealed` and so cannot add further
// `IntoMapRequestResult` implementations.
mod private {
    use crate::{http::Request, response::IntoResponse};

    pub trait Sealed<B> {}
    // The two return types accepted from the middleware function.
    impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
    impl<B> Sealed<B> for Request<B> {}
}
358 
/// Trait implemented by types that can be returned from [`map_request`] and
/// [`map_request_with_state`].
///
/// This trait is sealed such that it cannot be implemented outside this crate.
pub trait IntoMapRequestResult<B>: private::Sealed<B> {
    /// Perform the conversion.
    ///
    /// `Ok` carries the request that is passed on to the inner service; `Err` carries the
    /// response that is returned instead.
    fn into_map_request_result(self) -> Result<Request<B>, Response>;
}
367 
368 impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
369 where
370     E: IntoResponse,
371 {
into_map_request_result(self) -> Result<Request<B>, Response>372     fn into_map_request_result(self) -> Result<Request<B>, Response> {
373         self.map_err(IntoResponse::into_response)
374     }
375 }
376 
377 impl<B> IntoMapRequestResult<B> for Request<B> {
into_map_request_result(self) -> Result<Request<B>, Response>378     fn into_map_request_result(self) -> Result<Request<B>, Response> {
379         Ok(self)
380     }
381 }
382 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::{HeaderMap, StatusCode};

    // A plain `Request -> Request` function runs before the handler, which then observes the
    // header that the function inserted.
    #[crate::test]
    async fn works() {
        async fn add_header<B>(mut req: Request<B>) -> Request<B> {
            req.headers_mut().insert("x-foo", "foo".parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(add_header));
        let client = TestClient::new(app);

        let res = client.get("/").send().await;

        assert_eq!(res.text().await, "foo");
    }

    // Returning `Err` from the function answers the request directly; the handler must never
    // be reached.
    #[crate::test]
    async fn works_for_short_circuiting() {
        async fn reject<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
            Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
        }

        async fn handler(_headers: HeaderMap) -> Response {
            unreachable!()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(reject));
        let client = TestClient::new(app);

        let res = client.get("/").send().await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(res.text().await, "something went wrong");
    }
}
435