1 //! Async functions that can be used to handle requests.
2 //!
3 #![doc = include_str!("../docs/handlers_intro.md")]
4 //!
5 //! Some examples of handlers:
6 //!
7 //! ```rust
8 //! use axum::{body::Bytes, http::StatusCode};
9 //!
10 //! // Handler that immediately returns an empty `200 OK` response.
11 //! async fn unit_handler() {}
12 //!
//! // Handler that immediately returns a `200 OK` response with a plain text
//! // body.
15 //! async fn string_handler() -> String {
16 //!     "Hello, World!".to_string()
17 //! }
18 //!
19 //! // Handler that buffers the request body and returns it.
20 //! //
21 //! // This works because `Bytes` implements `FromRequest`
22 //! // and therefore can be used as an extractor.
23 //! //
24 //! // `String` and `StatusCode` both implement `IntoResponse` and
25 //! // therefore `Result<String, StatusCode>` also implements `IntoResponse`
26 //! async fn echo(body: Bytes) -> Result<String, StatusCode> {
27 //!     if let Ok(string) = String::from_utf8(body.to_vec()) {
28 //!         Ok(string)
29 //!     } else {
30 //!         Err(StatusCode::BAD_REQUEST)
31 //!     }
32 //! }
33 //! ```
34 //!
//! Instead of returning a `StatusCode` directly, it often makes sense to use an
//! intermediate error type that can ultimately be converted into a `Response`.
//! This allows using the `?` operator in handlers. See these examples:
38 //!
39 //! * [`anyhow-error-response`][anyhow] for generic boxed errors
40 //! * [`error-handling-and-dependency-injection`][ehdi] for application-specific detailed errors
41 //!
42 //! [anyhow]: https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs
43 //! [ehdi]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling-and-dependency-injection/src/main.rs
44 //!
45 #![doc = include_str!("../docs/debugging_handler_type_errors.md")]
46 
47 #[cfg(feature = "tokio")]
48 use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
49 use crate::{
50     body::Body,
51     extract::{FromRequest, FromRequestParts},
52     response::{IntoResponse, Response},
53     routing::IntoMakeService,
54 };
55 use http::Request;
56 use std::{convert::Infallible, fmt, future::Future, marker::PhantomData, pin::Pin};
57 use tower::ServiceExt;
58 use tower_layer::Layer;
59 use tower_service::Service;
60 
61 pub mod future;
62 mod service;
63 
64 pub use self::service::HandlerService;
65 
66 /// Trait for async functions that can be used to handle requests.
67 ///
/// You shouldn't need to depend on this trait directly. It is automatically
/// implemented for closures of the right types.
70 ///
71 /// See the [module docs](crate::handler) for more details.
72 ///
73 /// # Converting `Handler`s into [`Service`]s
74 ///
75 /// To convert `Handler`s into [`Service`]s you have to call either
76 /// [`HandlerWithoutStateExt::into_service`] or [`Handler::with_state`]:
77 ///
78 /// ```
79 /// use tower::Service;
80 /// use axum::{
81 ///     extract::State,
82 ///     body::Body,
83 ///     http::Request,
84 ///     handler::{HandlerWithoutStateExt, Handler},
85 /// };
86 ///
87 /// // this handler doesn't require any state
88 /// async fn one() {}
89 /// // so it can be converted to a service with `HandlerWithoutStateExt::into_service`
90 /// assert_service(one.into_service());
91 ///
92 /// // this handler requires state
93 /// async fn two(_: State<String>) {}
94 /// // so we have to provide it
95 /// let handler_with_state = two.with_state(String::new());
96 /// // which gives us a `Service`
97 /// assert_service(handler_with_state);
98 ///
99 /// // helper to check that a value implements `Service`
100 /// fn assert_service<S>(service: S)
101 /// where
102 ///     S: Service<Request<Body>>,
103 /// {}
104 /// ```
105 #[doc = include_str!("../docs/debugging_handler_type_errors.md")]
106 ///
107 /// # Handlers that aren't functions
108 ///
109 /// The `Handler` trait is also implemented for `T: IntoResponse`. That allows easily returning
110 /// fixed data for routes:
111 ///
112 /// ```
113 /// use axum::{
114 ///     Router,
115 ///     routing::{get, post},
116 ///     Json,
117 ///     http::StatusCode,
118 /// };
119 /// use serde_json::json;
120 ///
121 /// let app = Router::new()
122 ///     // respond with a fixed string
123 ///     .route("/", get("Hello, World!"))
124 ///     // or return some mock data
125 ///     .route("/users", post((
126 ///         StatusCode::CREATED,
127 ///         Json(json!({ "id": 1, "username": "alice" })),
128 ///     )));
129 /// # let _: Router = app;
130 /// ```
131 #[cfg_attr(
132     nightly_error_messages,
133     rustc_on_unimplemented(
134         note = "Consider using `#[axum::debug_handler]` to improve the error message"
135     )
136 )]
137 pub trait Handler<T, S, B = Body>: Clone + Send + Sized + 'static {
138     /// The type of future calling this handler returns.
139     type Future: Future<Output = Response> + Send + 'static;
140 
141     /// Call the handler with the given request.
call(self, req: Request<B>, state: S) -> Self::Future142     fn call(self, req: Request<B>, state: S) -> Self::Future;
143 
144     /// Apply a [`tower::Layer`] to the handler.
145     ///
146     /// All requests to the handler will be processed by the layer's
147     /// corresponding middleware.
148     ///
149     /// This can be used to add additional processing to a request for a single
150     /// handler.
151     ///
152     /// Note this differs from [`routing::Router::layer`](crate::routing::Router::layer)
153     /// which adds a middleware to a group of routes.
154     ///
155     /// If you're applying middleware that produces errors you have to handle the errors
156     /// so they're converted into responses. You can learn more about doing that
157     /// [here](crate::error_handling).
158     ///
159     /// # Example
160     ///
161     /// Adding the [`tower::limit::ConcurrencyLimit`] middleware to a handler
162     /// can be done like so:
163     ///
164     /// ```rust
165     /// use axum::{
166     ///     routing::get,
167     ///     handler::Handler,
168     ///     Router,
169     /// };
170     /// use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit};
171     ///
172     /// async fn handler() { /* ... */ }
173     ///
174     /// let layered_handler = handler.layer(ConcurrencyLimitLayer::new(64));
175     /// let app = Router::new().route("/", get(layered_handler));
176     /// # async {
177     /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
178     /// # };
179     /// ```
layer<L, NewReqBody>(self, layer: L) -> Layered<L, Self, T, S, B, NewReqBody> where L: Layer<HandlerService<Self, T, S, B>> + Clone, L::Service: Service<Request<NewReqBody>>,180     fn layer<L, NewReqBody>(self, layer: L) -> Layered<L, Self, T, S, B, NewReqBody>
181     where
182         L: Layer<HandlerService<Self, T, S, B>> + Clone,
183         L::Service: Service<Request<NewReqBody>>,
184     {
185         Layered {
186             layer,
187             handler: self,
188             _marker: PhantomData,
189         }
190     }
191 
192     /// Convert the handler into a [`Service`] by providing the state
with_state(self, state: S) -> HandlerService<Self, T, S, B>193     fn with_state(self, state: S) -> HandlerService<Self, T, S, B> {
194         HandlerService::new(self, state)
195     }
196 }
197 
198 impl<F, Fut, Res, S, B> Handler<((),), S, B> for F
199 where
200     F: FnOnce() -> Fut + Clone + Send + 'static,
201     Fut: Future<Output = Res> + Send,
202     Res: IntoResponse,
203     B: Send + 'static,
204 {
205     type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
206 
call(self, _req: Request<B>, _state: S) -> Self::Future207     fn call(self, _req: Request<B>, _state: S) -> Self::Future {
208         Box::pin(async move { self().await.into_response() })
209     }
210 }
211 
// Generates a `Handler` impl for async functions/closures that take one or
// more extractor arguments.
//
// Every argument except the last (`$ty`) must implement
// `FromRequestParts<S>` and only gets access to the request's parts. The last
// argument (`$last`) must implement `FromRequest<S, B, M>` and receives the
// whole request, body included. `M` is only forwarded to `FromRequest`
// (presumably a marker selecting between `FromRequest` impls — confirm in
// `extract`).
macro_rules! impl_handler {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the type parameter names are reused as the names
        // of the local variables holding the extracted values (`let $ty = ...`).
        // `unused_mut`: when `$ty` is empty, `parts` is never borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, B, Res, M, $($ty,)* $last> Handler<(M, $($ty,)* $last,), S, B> for F
        where
            F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static,
            Fut: Future<Output = Res> + Send,
            B: Send + 'static,
            S: Send + Sync + 'static,
            Res: IntoResponse,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S, B, M> + Send,
        {
            type Future = Pin<Box<dyn Future<Output = Response> + Send>>;

            fn call(self, req: Request<B>, state: S) -> Self::Future {
                Box::pin(async move {
                    // Split the request so the leading extractors can work on
                    // the parts without consuming the body.
                    let (mut parts, body) = req.into_parts();
                    let state = &state;

                    // Run the parts extractors in argument order. The first
                    // rejection is returned as the response and the handler
                    // itself is never called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request, including any changes the
                    // extractors above made to `parts`, for the final
                    // body-consuming extractor.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    let res = self($($ty,)* $last,).await;

                    res.into_response()
                })
            }
        }
    };
}

// `all_the_tuples!` is defined elsewhere in the crate; presumably it invokes
// `impl_handler!` once per supported number of extractor arguments — confirm.
all_the_tuples!(impl_handler);
258 
259 mod private {
260     // Marker type for `impl<T: IntoResponse> Handler for T`
261     #[allow(missing_debug_implementations)]
262     pub enum IntoResponseHandler {}
263 }
264 
265 impl<T, S, B> Handler<private::IntoResponseHandler, S, B> for T
266 where
267     T: IntoResponse + Clone + Send + 'static,
268     B: Send + 'static,
269 {
270     type Future = std::future::Ready<Response>;
271 
call(self, _req: Request<B>, _state: S) -> Self::Future272     fn call(self, _req: Request<B>, _state: S) -> Self::Future {
273         std::future::ready(self.into_response())
274     }
275 }
276 
/// A [`Service`] created from a [`Handler`] by applying a Tower middleware.
///
/// Created with [`Handler::layer`]. See that method for more details.
pub struct Layered<L, H, T, S, B, B2> {
    layer: L,
    handler: H,
    _marker: PhantomData<fn() -> (T, S, B, B2)>,
}

// Only the layer is shown; the handler is usually an anonymous fn/closure type
// with no useful `Debug` output, so `H` is not required to implement `Debug`.
impl<L, H, T, S, B, B2> fmt::Debug for Layered<L, H, T, S, B, B2>
where
    L: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Layered");
        builder.field("layer", &self.layer);
        builder.finish()
    }
}

// Written by hand instead of derived so that the phantom type parameters
// (`T`, `S`, `B`, `B2`) do not need to be `Clone`.
impl<L, H, T, S, B, B2> Clone for Layered<L, H, T, S, B, B2>
where
    L: Clone,
    H: Clone,
{
    fn clone(&self) -> Self {
        let layer = self.layer.clone();
        let handler = self.handler.clone();
        Self {
            layer,
            handler,
            _marker: PhantomData,
        }
    }
}
310 
311 impl<H, S, T, L, B, B2> Handler<T, S, B2> for Layered<L, H, T, S, B, B2>
312 where
313     L: Layer<HandlerService<H, T, S, B>> + Clone + Send + 'static,
314     H: Handler<T, S, B>,
315     L::Service: Service<Request<B2>, Error = Infallible> + Clone + Send + 'static,
316     <L::Service as Service<Request<B2>>>::Response: IntoResponse,
317     <L::Service as Service<Request<B2>>>::Future: Send,
318     T: 'static,
319     S: 'static,
320     B: Send + 'static,
321     B2: Send + 'static,
322 {
323     type Future = future::LayeredFuture<B2, L::Service>;
324 
call(self, req: Request<B2>, state: S) -> Self::Future325     fn call(self, req: Request<B2>, state: S) -> Self::Future {
326         use futures_util::future::{FutureExt, Map};
327 
328         let svc = self.handler.with_state(state);
329         let svc = self.layer.layer(svc);
330 
331         let future: Map<
332             _,
333             fn(
334                 Result<
335                     <L::Service as Service<Request<B2>>>::Response,
336                     <L::Service as Service<Request<B2>>>::Error,
337                 >,
338             ) -> _,
339         > = svc.oneshot(req).map(|result| match result {
340             Ok(res) => res.into_response(),
341             Err(err) => match err {},
342         });
343 
344         future::LayeredFuture::new(future)
345     }
346 }
347 
348 /// Extension trait for [`Handler`]s that don't have state.
349 ///
350 /// This provides convenience methods to convert the [`Handler`] into a [`Service`] or [`MakeService`].
351 ///
352 /// [`MakeService`]: tower::make::MakeService
353 pub trait HandlerWithoutStateExt<T, B>: Handler<T, (), B> {
354     /// Convert the handler into a [`Service`] and no state.
into_service(self) -> HandlerService<Self, T, (), B>355     fn into_service(self) -> HandlerService<Self, T, (), B>;
356 
357     /// Convert the handler into a [`MakeService`] and no state.
358     ///
359     /// See [`HandlerService::into_make_service`] for more details.
360     ///
361     /// [`MakeService`]: tower::make::MakeService
into_make_service(self) -> IntoMakeService<HandlerService<Self, T, (), B>>362     fn into_make_service(self) -> IntoMakeService<HandlerService<Self, T, (), B>>;
363 
364     /// Convert the handler into a [`MakeService`] which stores information
365     /// about the incoming connection and has no state.
366     ///
367     /// See [`HandlerService::into_make_service_with_connect_info`] for more details.
368     ///
369     /// [`MakeService`]: tower::make::MakeService
370     #[cfg(feature = "tokio")]
into_make_service_with_connect_info<C>( self, ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, (), B>, C>371     fn into_make_service_with_connect_info<C>(
372         self,
373     ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, (), B>, C>;
374 }
375 
376 impl<H, T, B> HandlerWithoutStateExt<T, B> for H
377 where
378     H: Handler<T, (), B>,
379 {
into_service(self) -> HandlerService<Self, T, (), B>380     fn into_service(self) -> HandlerService<Self, T, (), B> {
381         self.with_state(())
382     }
383 
into_make_service(self) -> IntoMakeService<HandlerService<Self, T, (), B>>384     fn into_make_service(self) -> IntoMakeService<HandlerService<Self, T, (), B>> {
385         self.into_service().into_make_service()
386     }
387 
388     #[cfg(feature = "tokio")]
into_make_service_with_connect_info<C>( self, ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, (), B>, C>389     fn into_make_service_with_connect_info<C>(
390         self,
391     ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, (), B>, C> {
392         self.into_service().into_make_service_with_connect_info()
393     }
394 }
395 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body, extract::State, test_helpers::*};
    use http::StatusCode;
    use std::time::Duration;
    use tower_http::{
        compression::CompressionLayer, limit::RequestBodyLimitLayer,
        map_request_body::MapRequestBodyLayer, map_response_body::MapResponseBodyLayer,
        timeout::TimeoutLayer,
    };

    /// A stateless handler can be turned into a service with `into_service`
    /// and served directly.
    #[crate::test]
    async fn handler_into_service() {
        async fn echo(body: String) -> impl IntoResponse {
            format!("you said: {body}")
        }

        let service = echo.into_service();
        let client = TestClient::new(service);

        let response = client.post("/").body("hi there!").send().await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "you said: hi there!");
    }

    /// Layers that change the request body type can be stacked on a handler
    /// that needs state, with the state supplied after layering.
    #[crate::test]
    async fn with_layer_that_changes_request_body_and_state() {
        async fn return_state(State(state): State<&'static str>) -> &'static str {
            state
        }

        let middleware = (
            RequestBodyLimitLayer::new(1024),
            TimeoutLayer::new(Duration::from_secs(10)),
            MapResponseBodyLayer::new(body::boxed),
            CompressionLayer::new(),
        );

        let service = return_state
            .layer(middleware)
            .layer(MapRequestBodyLayer::new(body::boxed))
            .with_state("foo");

        let client = TestClient::new(service);
        let response = client.get("/").send().await;

        assert_eq!(response.text().await, "foo");
    }
}
442