use crate::response::{IntoResponse, Response}; use axum_core::extract::{FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; use std::{ any::type_name, convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, }; use tower::{util::BoxCloneService, ServiceBuilder}; use tower_layer::Layer; use tower_service::Service; /// Create a middleware from an async function. /// /// `from_fn` requires the function given to /// /// 1. Be an `async fn`. /// 2. Take one or more [extractors] as the first arguments. /// 3. Take [`Next`](Next) as the final argument. /// 4. Return something that implements [`IntoResponse`]. /// /// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`]. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::{self, Request}, /// routing::get, /// response::Response, /// middleware::{self, Next}, /// }; /// /// async fn my_middleware( /// request: Request, /// next: Next, /// ) -> Response { /// // do something with `request`... /// /// let response = next.run(request).await; /// /// // do something with `response`... /// /// response /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .layer(middleware::from_fn(my_middleware)); /// # let app: Router = app; /// ``` /// /// # Running extractors /// /// ```rust /// use axum::{ /// Router, /// extract::TypedHeader, /// http::StatusCode, /// headers::authorization::{Authorization, Bearer}, /// http::Request, /// middleware::{self, Next}, /// response::Response, /// routing::get, /// }; /// /// async fn auth( /// // run the `TypedHeader` extractor /// TypedHeader(auth): TypedHeader>, /// // you can also add more extractors here but the last /// // extractor must implement `FromRequest` which /// // `Request` does /// request: Request, /// next: Next, /// ) -> Result { /// if token_is_valid(auth.token()) { /// let response = next.run(request).await; /// Ok(response) /// } else { /// Err(StatusCode::UNAUTHORIZED) /// } /// } /// /// fn token_is_valid(token: &str) -> bool { /// // ... /// # false /// } /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn(auth)); /// # let app: Router = app; /// ``` /// /// [extractors]: crate::extract::FromRequest /// [`State`]: crate::extract::State pub fn from_fn(f: F) -> FromFnLayer { from_fn_with_state((), f) } /// Create a middleware from an async function with the given state. /// /// See [`State`](crate::extract::State) for more details about accessing state. /// /// # Example /// /// ```rust /// use axum::{ /// Router, /// http::{Request, StatusCode}, /// routing::get, /// response::{IntoResponse, Response}, /// middleware::{self, Next}, /// extract::State, /// }; /// /// #[derive(Clone)] /// struct AppState { /* ... */ } /// /// async fn my_middleware( /// State(state): State, /// // you can add more extractors here but the last /// // extractor must implement `FromRequest` which /// // `Request` does /// request: Request, /// next: Next, /// ) -> Response { /// // do something with `request`... /// /// let response = next.run(request).await; /// /// // do something with `response`... /// /// response /// } /// /// let state = AppState { /* ... */ }; /// /// let app = Router::new() /// .route("/", get(|| async { /* ... */ })) /// .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware)) /// .with_state(state); /// # let _: axum::Router = app; /// ``` pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { FromFnLayer { f, state, _extractor: PhantomData, } } /// A [`tower::Layer`] from an async function. /// /// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s. /// /// Created with [`from_fn`]. See that function for more details. #[must_use] pub struct FromFnLayer { f: F, state: S, _extractor: PhantomData T>, } impl Clone for FromFnLayer where F: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), state: self.state.clone(), _extractor: self._extractor, } } } impl Layer for FromFnLayer where F: Clone, S: Clone, { type Service = FromFn; fn layer(&self, inner: I) -> Self::Service { FromFn { f: self.f.clone(), state: self.state.clone(), inner, _extractor: PhantomData, } } } impl fmt::Debug for FromFnLayer where S: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") // Write out the type name, without quoting it as `&type_name::()` would .field("f", &format_args!("{}", type_name::())) .field("state", &self.state) .finish() } } /// A middleware created from an async function. /// /// Created with [`from_fn`]. See that function for more details. pub struct FromFn { f: F, inner: I, state: S, _extractor: PhantomData T>, } impl Clone for FromFn where F: Clone, I: Clone, S: Clone, { fn clone(&self) -> Self { Self { f: self.f.clone(), inner: self.inner.clone(), state: self.state.clone(), _extractor: self._extractor, } } } macro_rules! impl_service { ( [$($ty:ident),*], $last:ident ) => { #[allow(non_snake_case, unused_mut)] impl Service> for FromFn where F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static, $( $ty: FromRequestParts + Send, )* $last: FromRequest + Send, Fut: Future + Send + 'static, Out: IntoResponse + 'static, I: Service, Error = Infallible> + Clone + Send + 'static, I::Response: IntoResponse, I::Future: Send + 'static, B: Send + 'static, S: Clone + Send + Sync + 'static, { type Response = Response; type Error = Infallible; type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let not_ready_inner = self.inner.clone(); let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); let mut f = self.f.clone(); let state = self.state.clone(); let future = Box::pin(async move { let (mut parts, body) = req.into_parts(); $( let $ty = match $ty::from_request_parts(&mut parts, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; )* let req = Request::from_parts(parts, body); let $last = match $last::from_request(req, &state).await { Ok(value) => value, Err(rejection) => return rejection.into_response(), }; let inner = ServiceBuilder::new() .boxed_clone() .map_response(IntoResponse::into_response) .service(ready_inner); let next = Next { inner }; f($($ty,)* $last, next).await.into_response() }); ResponseFuture { inner: future } } } }; } all_the_tuples!(impl_service); impl fmt::Debug for FromFn where S: fmt::Debug, I: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("f", &format_args!("{}", type_name::())) .field("inner", &self.inner) .field("state", &self.state) .finish() } } /// The remainder of a middleware stack, including the handler. pub struct Next { inner: BoxCloneService, Response, Infallible>, } impl Next { /// Execute the remaining middleware stack. pub async fn run(mut self, req: Request) -> Response { match self.inner.call(req).await { Ok(res) => res, Err(err) => match err {}, } } } impl fmt::Debug for Next { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FromFnLayer") .field("inner", &self.inner) .finish() } } impl Clone for Next { fn clone(&self) -> Self { Self { inner: self.inner.clone(), } } } impl Service> for Next { type Response = Response; type Error = Infallible; type Future = Pin> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { self.inner.call(req) } } /// Response future for [`FromFn`]. pub struct ResponseFuture { inner: BoxFuture<'static, Response>, } impl Future for ResponseFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.as_mut().poll(cx).map(Ok) } } impl fmt::Debug for ResponseFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ResponseFuture").finish() } } #[cfg(test)] mod tests { use super::*; use crate::{body::Body, routing::get, Router}; use http::{HeaderMap, StatusCode}; use tower::ServiceExt; #[crate::test] async fn basic() { async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse { req.headers_mut() .insert("x-axum-test", "ok".parse().unwrap()); next.run(req).await } async fn handle(headers: HeaderMap) -> String { headers["x-axum-test"].to_str().unwrap().to_owned() } let app = Router::new() .route("/", get(handle)) .layer(from_fn(insert_header)); let res = app .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(res.status(), StatusCode::OK); let body = hyper::body::to_bytes(res).await.unwrap(); assert_eq!(&body[..], b"ok"); } }