use crate::response::{IntoResponse, Response}; use axum_core::extract::{FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; use std::{ any::type_name, convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower_layer::Layer; use tower_service::Service; /// Create a middleware from an async function that transforms a request. /// /// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific /// extractors. /// /// # Example /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_request, /// http::Request, /// }; /// /// async fn set_header(mut request: Request) -> Request { /// request.headers_mut().insert("x-foo", "foo".parse().unwrap()); /// request /// } /// /// async fn handler(request: Request) { /// // `request` will have an `x-foo` header /// } /// /// let app = Router::new() /// .route("/", get(handler)) /// .layer(map_request(set_header)); /// # let _: Router = app; /// ``` /// /// # Rejecting the request /// /// The function given to `map_request` is allowed to also return a `Result` which can be used to /// reject the request and return a response immediately, without calling the remaining /// middleware. /// /// Specifically the valid return types are: /// /// - `Request` /// - `Result, E> where E: IntoResponse` /// /// ``` /// use axum::{ /// Router, /// http::{Request, StatusCode}, /// routing::get, /// middleware::map_request, /// }; /// /// async fn auth(request: Request) -> Result, StatusCode> { /// let auth_header = request.headers() /// .get(http::header::AUTHORIZATION) /// .and_then(|header| header.to_str().ok()); /// /// match auth_header { /// Some(auth_header) if token_is_valid(auth_header) => Ok(request), /// _ => Err(StatusCode::UNAUTHORIZED), /// } /// } /// /// fn token_is_valid(token: &str) -> bool { /// // ... /// # false /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(map_request(auth)); /// # let app: Router = app; /// ``` /// /// # Running extractors /// /// ``` /// use axum::{ /// Router, /// routing::get, /// middleware::map_request, /// extract::Path, /// http::Request, /// }; /// use std::collections::HashMap; /// /// async fn log_path_params( /// Path(path_params): Path>, /// request: Request, /// ) -> Request { /// tracing::debug!(?path_params); /// request /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .layer(map_request(log_path_params)); /// # let _: Router = app; /// ``` /// /// Note that to access state you must use either [`map_request_with_state`]. pub fn map_request(f: F) -> MapRequestLayer { map_request_with_state((), f) } /// Create a middleware from an async function that transforms a request, with the given state. /// /// See [`State`](crate::extract::State) for more details about accessing state. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::{Request, StatusCode}, /// routing::get, /// response::IntoResponse, /// middleware::map_request_with_state, /// extract::State, /// }; /// /// #[derive(Clone)] /// struct AppState { /* ... */ } /// /// async fn my_middleware( /// State(state): State, /// // you can add more extractors here but the last /// // extractor must implement `FromRequest` which /// // `Request` does /// request: Request, /// ) -> Request { /// // do something with `state` and `request`... /// request /// } /// /// let state = AppState { /* ... */ }; /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(map_request_with_state(state.clone(), my_middleware)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` pub fn map_request_with_state(state: S, f: F) -> MapRequestLayer { MapRequestLayer { f, state, _extractor: PhantomData, } } /// A [`tower::Layer`] from an async function that transforms a request. /// /// Created with [`map_request`]. See that function for more details. #[must_use] pub struct MapRequestLayer { f: F, state: S, _extractor: PhantomData T>, } impl Clone for MapRequestLayer where F: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), state: self.state.clone(), _extractor: self._extractor, } } } impl Layer for MapRequestLayer where F: Clone, S: Clone, { type Service = MapRequest; fn layer(&self, inner: I) -> Self::Service { MapRequest { f: self.f.clone(), state: self.state.clone(), inner, _extractor: PhantomData, } } } impl fmt::Debug for MapRequestLayer where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapRequestLayer") // Write out the type name, without quoting it as `&type_name::()` would .field("f", &format_args!("{}", type_name::())) .field("state", &self.state) .finish() } } /// A middleware created from an async function that transforms a request. /// /// Created with [`map_request`]. See that function for more details. pub struct MapRequest { f: F, inner: I, state: S, _extractor: PhantomData T>, } impl Clone for MapRequest where F: Clone, I: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), inner: self.inner.clone(), state: self.state.clone(), _extractor: self._extractor, } } } macro_rules! impl_service { ( [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused_mut)] impl Service> for MapRequest where F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static, $( $ty: FromRequestParts + Send, )* $last: FromRequest + Send, Fut: Future + Send + 'static, Fut::Output: IntoMapRequestResult + Send + 'static, I: Service, Error = Infallible> + Clone + Send + 'static, I::Response: IntoResponse, I::Future: Send + 'static, B: Send + 'static, S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let mut f = self.f.clone(); let state = self.state.clone(); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* let req = Request::from_parts(parts, body); let $last = match $last::from_request(req, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; match f($($ty,)* $last).await.into_map_request_result() { Ok(req) => { ready_inner.call(req).await.into_response() } Err(res) => { res } } }); ResponseFuture { inner: future } } } }; } all_the_tuples!(impl_service); impl fmt::Debug for MapRequest where S: fmt::Debug, I: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("MapRequest") .field("f", &format_args!("{}", type_name::())) .field("inner", &self.inner) .field("state", &self.state) .finish() } } /// Response future for [`MapRequest`]. pub struct ResponseFuture { inner: BoxFuture<'static, Response>, } impl Future for ResponseFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.as_mut().poll(cx).map(Ok) } } impl fmt::Debug for ResponseFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ResponseFuture").finish() } } mod private { use crate::{http::Request, response::IntoResponse}; pub trait Sealed {} impl Sealed for Result, E> where E: IntoResponse {} impl Sealed for Request {} } /// Trait implemented by types that can be returned from [`map_request`], /// [`map_request_with_state`]. /// /// This trait is sealed such that it cannot be implemented outside this crate. pub trait IntoMapRequestResult: private::Sealed { /// Perform the conversion. fn into_map_request_result(self) -> Result, Response>; } impl IntoMapRequestResult for Result, E> where E: IntoResponse, { fn into_map_request_result(self) -> Result, Response> { self.map_err(IntoResponse::into_response) } } impl IntoMapRequestResult for Request { fn into_map_request_result(self) -> Result, Response> { Ok(self) } } #[cfg(test)] mod tests { use super::*; use crate::{routing::get, test_helpers::TestClient, Router}; use http::{HeaderMap, StatusCode}; #[crate::test] async fn works() { async fn add_header(mut req: Request) -> Request { req.headers_mut().insert("x-foo", "foo".parse().unwrap()); req } async fn handler(headers: HeaderMap) -> Response { headers["x-foo"] .to_str() .unwrap() .to_owned() .into_response() } let app = Router::new() .route("/", get(handler)) .layer(map_request(add_header)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.text().await, "foo"); } #[crate::test] async fn works_for_short_circutting() { async fn add_header(_req: Request) -> Result, (StatusCode, &'static str)> { Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong")) } async fn handler(_headers: HeaderMap) -> Response { unreachable!() } let app = Router::new() .route("/", get(handler)) .layer(map_request(add_header)); let client = TestClient::new(app); let res = client.get("/").send().await; assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(res.text().await, "something went wrong"); } }