//! Route to services and handlers based on HTTP methods.
use super::{future::InfallibleRouteFuture, IntoMakeService};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::{Body, Bytes, HttpBody},
boxed::BoxedIntoRoute,
error_handling::{HandleError, HandleErrorLayer},
handler::Handler,
http::{Method, Request, StatusCode},
response::Response,
routing::{future::RouteFuture, Fallback, MethodFilter, Route},
};
use axum_core::response::IntoResponse;
use bytes::BytesMut;
use std::{
convert::Infallible,
fmt,
task::{Context, Poll},
};
use tower::{service_fn, util::MapResponseLayer};
use tower_layer::Layer;
use tower_service::Service;
// Generates the free-standing `*_service` constructors (`get_service`,
// `post_service`, ...). Each one builds a `MethodRouter` that routes a single
// HTTP method to a `Service` by delegating to `on_service`.
//
// Arm order matters: the `GET` arm must come before the generic
// `$method:ident` arm so that `get_service` receives the full example docs;
// every other method gets a short doc that links back to `get_service`. Both
// of those arms re-invoke the macro with doc attributes, which is then matched
// by the final attribute-carrying arm that emits the function.
macro_rules! top_level_service_fn {
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     http::Request,
            ///     Router,
            ///     routing::get_service,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            /// use hyper::Body;
            ///
            /// let service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. If a `HEAD` route is registered explicitly it
            /// handles `HEAD` requests instead, regardless of the order in which the two
            /// routes were added.
            $name,
            GET
        );
    };
    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $name,
            $method
        );
    };
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name(svc: T) -> MethodRouter
        where
            T: Service> + Clone + Send + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            B: HttpBody + Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
// Generates the free-standing handler constructors (`get`, `post`, ...). Each
// builds a `MethodRouter` that routes one HTTP method to a `Handler` via `on`.
//
// The attribute-carrying arm is listed first: it only matches input that
// begins with `#[...]`, so plain `name, METHOD` input falls through to the
// `GET`-specific arm (full example docs) or the generic arm (short docs), and
// both of those re-invoke this macro with the docs attached.
macro_rules! top_level_handler_fn {
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        pub fn $fn_name(handler: H) -> MethodRouter
        where
            H: Handler,
            B: HttpBody + Send + 'static,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$verb, handler)
        }
    };
    (
        $fn_name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Routes `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// A `get` route also answers `HEAD` requests, with the response body
            /// stripped. When an explicit `HEAD` route is registered it takes over
            /// `HEAD` requests, no matter which of the two was added first.
            $fn_name,
            GET
        );
    };
    (
        $fn_name:ident, $verb:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($verb) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $fn_name,
            $verb
        );
    };
}
// Generates the chaining `*_service` methods on `MethodRouter`
// (`router.get_service(svc)`, ...). Each one registers a `Service` for a
// single HTTP method through `MethodRouter::on_service`.
//
// The arm that carries doc attributes is matched first; it requires a leading
// `#[...]`, so bare `name, METHOD` input falls through to the `GET` arm or the
// generic arm, which add docs and recurse.
macro_rules! chained_service_fn {
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $fn_name(self, svc: T) -> Self
        where
            T: Service, Error = E> + Clone + Send + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$verb, svc)
        }
    };
    (
        $fn_name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chains an extra service that only accepts `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     http::Request,
            ///     Router,
            ///     routing::post_service,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            /// use hyper::Body;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// A `get` route also answers `HEAD` requests, with the response body
            /// stripped. When an explicit `HEAD` route is registered it takes over
            /// `HEAD` requests, no matter which of the two was added first.
            $fn_name,
            GET
        );
    };
    (
        $fn_name:ident, $verb:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($verb),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $fn_name,
            $verb
        );
    };
}
// Generates the chaining handler methods on `MethodRouter`
// (`router.get(handler)`, ...). Each one registers a `Handler` for a single
// HTTP method through `MethodRouter::on`.
//
// The attribute-carrying arm comes first and only matches input that starts
// with `#[...]`; plain `name, METHOD` input is handled by the `GET` arm or the
// generic arm, which attach docs and recurse into the first arm.
macro_rules! chained_handler_fn {
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $fn_name(self, handler: H) -> Self
        where
            H: Handler,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$verb, handler)
        }
    };
    (
        $fn_name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chains an extra handler that only accepts `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// A `get` route also answers `HEAD` requests, with the response body
            /// stripped. When an explicit `HEAD` route is registered it takes over
            /// `HEAD` requests, no matter which of the two was added first.
            $fn_name,
            GET
        );
    };
    (
        $fn_name:ident, $verb:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($verb),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $fn_name,
            $verb
        );
    };
}
// One free-standing `*_service` constructor per supported HTTP method.
// Item order is irrelevant to Rust; these are grouped by how commonly they are used.
top_level_service_fn!(get_service, GET);
top_level_service_fn!(head_service, HEAD);
top_level_service_fn!(post_service, POST);
top_level_service_fn!(put_service, PUT);
top_level_service_fn!(patch_service, PATCH);
top_level_service_fn!(delete_service, DELETE);
top_level_service_fn!(options_service, OPTIONS);
top_level_service_fn!(trace_service, TRACE);
/// Routes every request whose method is selected by `filter` to `svc`.
///
/// This is shorthand for `MethodRouter::new().on_service(filter, svc)`.
///
/// # Example
///
/// ```rust
/// use axum::{
///     http::Request,
///     routing::on,
///     Router,
///     routing::{MethodFilter, on_service},
/// };
/// use http::Response;
/// use std::convert::Infallible;
/// use hyper::Body;
///
/// let service = tower::service_fn(|request: Request| async {
///     Ok::<_, Infallible>(Response::new(Body::empty()))
/// });
///
/// // Requests to `POST /` will go to `service`.
/// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn on_service(filter: MethodFilter, svc: T) -> MethodRouter
where
    T: Service> + Clone + Send + 'static,
    T::Response: IntoResponse + 'static,
    T::Future: Send + 'static,
    B: HttpBody + Send + 'static,
    S: Clone,
{
    let empty_router = MethodRouter::new();
    empty_router.on_service(filter, svc)
}
/// Sends requests to `svc` whatever their HTTP method is.
///
/// Internally `svc` becomes the router's fallback, and the router is marked so
/// that no `Allow` header is added to its responses.
///
/// # Example
///
/// ```rust
/// use axum::{
///     http::Request,
///     Router,
///     routing::any_service,
/// };
/// use http::Response;
/// use std::convert::Infallible;
/// use hyper::Body;
///
/// let service = tower::service_fn(|request: Request| async {
///     Ok::<_, Infallible>(Response::new(Body::empty()))
/// });
///
/// // All requests to `/` will go to `service`.
/// let app = Router::new().route("/", any_service(service));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Method-specific services can still be chained on top; they win over the
/// catch-all for their method:
///
/// ```rust
/// use axum::{
///     http::Request,
///     Router,
///     routing::any_service,
/// };
/// use http::Response;
/// use std::convert::Infallible;
/// use hyper::Body;
///
/// let service = tower::service_fn(|request: Request| async {
///     # Ok::<_, Infallible>(Response::new(Body::empty()))
///     // ...
/// });
///
/// let other_service = tower::service_fn(|request: Request| async {
///     # Ok::<_, Infallible>(Response::new(Body::empty()))
///     // ...
/// });
///
/// // `POST /` goes to `other_service`. All other requests go to `service`
/// let app = Router::new().route("/", any_service(service).post_service(other_service));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn any_service(svc: T) -> MethodRouter
where
    T: Service> + Clone + Send + 'static,
    T::Response: IntoResponse + 'static,
    T::Future: Send + 'static,
    B: HttpBody + Send + 'static,
    S: Clone,
{
    let catch_all = MethodRouter::new().fallback_service(svc);
    catch_all.skip_allow_header()
}
// One free-standing handler constructor per supported HTTP method.
// Item order is irrelevant to Rust; these are grouped by how commonly they are used.
top_level_handler_fn!(get, GET);
top_level_handler_fn!(head, HEAD);
top_level_handler_fn!(post, POST);
top_level_handler_fn!(put, PUT);
top_level_handler_fn!(patch, PATCH);
top_level_handler_fn!(delete, DELETE);
top_level_handler_fn!(options, OPTIONS);
top_level_handler_fn!(trace, TRACE);
/// Routes every request whose method is selected by `filter` to `handler`.
///
/// This is shorthand for `MethodRouter::new().on(filter, handler)`.
///
/// # Example
///
/// ```rust
/// use axum::{
///     routing::on,
///     Router,
///     routing::MethodFilter,
/// };
///
/// async fn handler() {}
///
/// // Requests to `POST /` will go to `handler`.
/// let app = Router::new().route("/", on(MethodFilter::POST, handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn on(filter: MethodFilter, handler: H) -> MethodRouter
where
    H: Handler,
    B: HttpBody + Send + 'static,
    T: 'static,
    S: Clone + Send + Sync + 'static,
{
    let empty_router = MethodRouter::new();
    empty_router.on(filter, handler)
}
/// Sends requests to `handler` whatever their HTTP method is.
///
/// Internally `handler` becomes the router's fallback, and the router is marked
/// so that no `Allow` header is added to its responses.
///
/// # Example
///
/// ```rust
/// use axum::{
///     routing::any,
///     Router,
/// };
///
/// async fn handler() {}
///
/// // All requests to `/` will go to `handler`.
/// let app = Router::new().route("/", any(handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// Method-specific handlers can still be chained on top; they win over the
/// catch-all for their method:
///
/// ```rust
/// use axum::{
///     routing::any,
///     Router,
/// };
///
/// async fn handler() {}
///
/// async fn other_handler() {}
///
/// // `POST /` goes to `other_handler`. All other requests go to `handler`
/// let app = Router::new().route("/", any(handler).post(other_handler));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
pub fn any(handler: H) -> MethodRouter
where
    H: Handler,
    B: HttpBody + Send + 'static,
    T: 'static,
    S: Clone + Send + Sync + 'static,
{
    let catch_all = MethodRouter::new().fallback(handler);
    catch_all.skip_allow_header()
}
/// A [`Service`] that dispatches requests by HTTP method (as selected with a
/// [`MethodFilter`]) and lets further handlers and services be chained on.
///
/// # When does `MethodRouter` implement [`Service`]?
///
/// That depends on the state type the router still needs:
///
/// ```
/// use tower::Service;
/// use axum::{routing::get, extract::State, body::Body, http::Request};
///
/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
/// let method_router = get(|| async {});
/// // and thus it implements `Service`
/// assert_service(method_router);
///
/// // this requires a `String` and doesn't implement `Service`
/// let method_router = get(|_: State| async {});
/// // until you provide the `String` with `.with_state(...)`
/// let method_router_with_state = method_router.with_state(String::new());
/// // and then it implements `Service`
/// assert_service(method_router_with_state);
///
/// // helper to check that a value implements `Service`
/// fn assert_service(service: S)
/// where
///     S: Service>,
/// {}
/// ```
#[must_use]
pub struct MethodRouter {
    // One slot per supported method; `MethodEndpoint::None` means "not routed".
    get: MethodEndpoint,
    head: MethodEndpoint,
    post: MethodEndpoint,
    put: MethodEndpoint,
    patch: MethodEndpoint,
    delete: MethodEndpoint,
    options: MethodEndpoint,
    trace: MethodEndpoint,
    // Used when no method slot matches the request.
    fallback: Fallback,
    // Value of the `Allow` header attached to fallback responses.
    allow_header: AllowHeader,
}
/// Tracks the value of the `Allow` header a `MethodRouter` sends with its
/// fallback responses.
#[derive(Debug, Clone)]
enum AllowHeader {
    /// Nothing has been accumulated yet; this is the starting state.
    None,
    /// The comma-separated list of methods accumulated so far.
    Bytes(BytesMut),
    /// Never send an `Allow` header (set by `any` / `any_service`).
    Skip,
}
impl AllowHeader {
    /// Combines the `Allow` header state of two method routers that are being merged.
    ///
    /// * `Skip` on either side wins: the merged router sends no `Allow` header.
    /// * `None` acts as the identity: the other side's state is kept.
    /// * Two accumulated values are joined with `,`. Methods from `other` that are
    ///   already listed in `self` are not repeated. For example, merging a `get`
    ///   router (which advertises `GET,HEAD`) with a `head` router (which advertises
    ///   `HEAD`) yields `GET,HEAD` rather than `GET,HEAD,HEAD`.
    fn merge(self, other: Self) -> Self {
        match (self, other) {
            (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
            (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
            (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
            (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
            (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
                // Borrow as plain slices: `BytesMut` has an inherent `split(&mut self)`
                // that would otherwise shadow `<[u8]>::split`.
                let incoming: &[u8] = &b;
                for method in incoming.split(|&byte| byte == b',') {
                    // Skip empty segments so no stray commas are produced.
                    if method.is_empty() {
                        continue;
                    }
                    let current: &[u8] = &a;
                    let already_listed = current
                        .split(|&byte| byte == b',')
                        .any(|existing| existing == method);
                    if already_listed {
                        continue;
                    }
                    if !a.is_empty() {
                        a.extend_from_slice(b",");
                    }
                    a.extend_from_slice(method);
                }
                AllowHeader::Bytes(a)
            }
        }
    }
}
impl fmt::Debug for MethodRouter {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure so that adding a field without updating this impl is a compile error.
        let Self {
            get,
            head,
            delete,
            options,
            patch,
            post,
            put,
            trace,
            fallback,
            allow_header,
        } = self;

        let mut builder = f.debug_struct("MethodRouter");
        builder.field("get", get);
        builder.field("head", head);
        builder.field("delete", delete);
        builder.field("options", options);
        builder.field("patch", patch);
        builder.field("post", post);
        builder.field("put", put);
        builder.field("trace", trace);
        builder.field("fallback", fallback);
        builder.field("allow_header", allow_header);
        builder.finish()
    }
}
impl MethodRouter
where
    B: HttpBody + Send + 'static,
    S: Clone,
{
    /// Chains an extra handler for every method selected by `filter`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use axum::{
    ///     routing::get,
    ///     Router,
    ///     routing::MethodFilter
    /// };
    ///
    /// async fn handler() {}
    ///
    /// async fn other_handler() {}
    ///
    /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
    /// // `other_handler`
    /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
    /// # async {
    /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
    /// # };
    /// ```
    #[track_caller]
    pub fn on(self, filter: MethodFilter, handler: H) -> Self
    where
        H: Handler,
        T: 'static,
        S: Send + Sync + 'static,
    {
        // The handler is type-erased here; it is turned into a route once state is known.
        let boxed = BoxedIntoRoute::from_handler(handler);
        self.on_endpoint(filter, MethodEndpoint::BoxedHandler(boxed))
    }

    chained_handler_fn!(get, GET);
    chained_handler_fn!(head, HEAD);
    chained_handler_fn!(post, POST);
    chained_handler_fn!(put, PUT);
    chained_handler_fn!(patch, PATCH);
    chained_handler_fn!(delete, DELETE);
    chained_handler_fn!(options, OPTIONS);
    chained_handler_fn!(trace, TRACE);

    /// Add a fallback [`Handler`] to the router, replacing the current fallback.
    pub fn fallback(mut self, handler: H) -> Self
    where
        H: Handler,
        T: 'static,
        S: Send + Sync + 'static,
    {
        let boxed = BoxedIntoRoute::from_handler(handler);
        self.fallback = Fallback::BoxedHandler(boxed);
        self
    }
}
impl MethodRouter<(), B, Infallible>
where
    B: HttpBody + Send + 'static,
{
    /// Turns this method router into a [`MakeService`] so it can be served on
    /// its own, without a surrounding `Router`:
    ///
    /// ```rust
    /// use axum::{
    ///     Server,
    ///     handler::Handler,
    ///     http::{Uri, Method},
    ///     response::IntoResponse,
    ///     routing::get,
    /// };
    /// use std::net::SocketAddr;
    ///
    /// async fn handler(method: Method, uri: Uri, body: String) -> String {
    ///     format!("received `{} {}` with body `{:?}`", method, uri, body)
    /// }
    ///
    /// let router = get(handler).post(handler);
    ///
    /// # async {
    /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
    ///     .serve(router.into_make_service())
    ///     .await?;
    /// # Ok::<_, hyper::Error>(())
    /// # };
    /// ```
    ///
    /// [`MakeService`]: tower::make::MakeService
    pub fn into_make_service(self) -> IntoMakeService {
        // The router needs no state, so it is finalised with `()`.
        let stateless = self.with_state(());
        IntoMakeService::new(stateless)
    }

    /// Like `into_make_service`, but the produced [`MakeService`] also records
    /// information about each incoming connection.
    ///
    /// See [`Router::into_make_service_with_connect_info`] for more details.
    ///
    /// ```rust
    /// use axum::{
    ///     Server,
    ///     handler::Handler,
    ///     response::IntoResponse,
    ///     extract::ConnectInfo,
    ///     routing::get,
    /// };
    /// use std::net::SocketAddr;
    ///
    /// async fn handler(ConnectInfo(addr): ConnectInfo) -> String {
    ///     format!("Hello {}", addr)
    /// }
    ///
    /// let router = get(handler).post(handler);
    ///
    /// # async {
    /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
    ///     .serve(router.into_make_service_with_connect_info::())
    ///     .await?;
    /// # Ok::<_, hyper::Error>(())
    /// # };
    /// ```
    ///
    /// [`MakeService`]: tower::make::MakeService
    /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
    #[cfg(feature = "tokio")]
    pub fn into_make_service_with_connect_info(self) -> IntoMakeServiceWithConnectInfo {
        let stateless = self.with_state(());
        IntoMakeServiceWithConnectInfo::new(stateless)
    }
}
impl MethodRouter
where
B: HttpBody + Send + 'static,
S: Clone,
{
/// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
/// requests.
///
/// Every method slot starts out empty. The `Allow` header state starts as
/// `AllowHeader::None`, which `call_with_state` turns into an empty `Allow`
/// header on fallback responses.
pub fn new() -> Self {
// The default fallback ignores the request and always answers 405.
let fallback = Route::new(service_fn(|_: Request| async {
Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
}));
Self {
get: MethodEndpoint::None,
head: MethodEndpoint::None,
delete: MethodEndpoint::None,
options: MethodEndpoint::None,
patch: MethodEndpoint::None,
post: MethodEndpoint::None,
put: MethodEndpoint::None,
trace: MethodEndpoint::None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback),
}
}
/// Provide the state for the router.
///
/// Each method slot is converted with `MethodEndpoint::with_state`. That turns
/// handler endpoints into routes using a clone of `state` and leaves service
/// endpoints as they are. The fallback is given `state` itself, and the `Allow`
/// header state is carried over unchanged.
pub fn with_state(self, state: S) -> MethodRouter {
MethodRouter {
get: self.get.with_state(&state),
head: self.head.with_state(&state),
delete: self.delete.with_state(&state),
options: self.options.with_state(&state),
patch: self.patch.with_state(&state),
post: self.post.with_state(&state),
put: self.put.with_state(&state),
trace: self.trace.with_state(&state),
allow_header: self.allow_header,
fallback: self.fallback.with_state(state),
}
}
/// Chain an additional service that will accept requests matching the given
/// `MethodFilter`.
///
/// # Example
///
/// ```rust
/// use axum::{
/// http::Request,
/// Router,
/// routing::{MethodFilter, on_service},
/// };
/// use http::Response;
/// use std::convert::Infallible;
/// use hyper::Body;
///
/// let service = tower::service_fn(|request: Request| async {
/// Ok::<_, Infallible>(Response::new(Body::empty()))
/// });
///
/// // Requests to `DELETE /` will go to `service`
/// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Panics
///
/// Panics if a route is already registered for any method selected by `filter`.
#[track_caller]
pub fn on_service(self, filter: MethodFilter, svc: T) -> Self
where
T: Service, Error = E> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
}
/// Stores `endpoint` in every method slot selected by `filter` and records
/// the matching methods in the `Allow` header state.
///
/// Registering `GET` also advertises `HEAD`, because a `GET` endpoint answers
/// `HEAD` requests too (see `call_with_state`).
///
/// # Panics
///
/// Panics if any selected slot already holds an endpoint.
#[track_caller]
fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint) -> Self {
// written as a separate function to generate less IR
// Parameter roles: `endpoint_filter` is the caller's filter (possibly several
// methods); `filter` is the single-method filter for the slot `out`. The call
// sites below pass them in that order.
#[track_caller]
fn set_endpoint(
method_name: &str,
out: &mut MethodEndpoint,
endpoint: &MethodEndpoint,
endpoint_filter: MethodFilter,
filter: MethodFilter,
allow_header: &mut AllowHeader,
methods: &[&'static str],
) where
MethodEndpoint: Clone,
S: Clone,
{
if endpoint_filter.contains(filter) {
// Two routes for the same method are a configuration error.
if out.is_some() {
panic!(
"Overlapping method route. Cannot add two method routes that both handle \
`{method_name}`",
)
}
// Each selected slot gets its own clone of the endpoint.
*out = endpoint.clone();
for method in methods {
append_allow_header(allow_header, method);
}
}
}
set_endpoint(
"GET",
&mut self.get,
&endpoint,
filter,
MethodFilter::GET,
&mut self.allow_header,
&["GET", "HEAD"],
);
set_endpoint(
"HEAD",
&mut self.head,
&endpoint,
filter,
MethodFilter::HEAD,
&mut self.allow_header,
&["HEAD"],
);
set_endpoint(
"TRACE",
&mut self.trace,
&endpoint,
filter,
MethodFilter::TRACE,
&mut self.allow_header,
&["TRACE"],
);
set_endpoint(
"PUT",
&mut self.put,
&endpoint,
filter,
MethodFilter::PUT,
&mut self.allow_header,
&["PUT"],
);
set_endpoint(
"POST",
&mut self.post,
&endpoint,
filter,
MethodFilter::POST,
&mut self.allow_header,
&["POST"],
);
set_endpoint(
"PATCH",
&mut self.patch,
&endpoint,
filter,
MethodFilter::PATCH,
&mut self.allow_header,
&["PATCH"],
);
set_endpoint(
"OPTIONS",
&mut self.options,
&endpoint,
filter,
MethodFilter::OPTIONS,
&mut self.allow_header,
&["OPTIONS"],
);
set_endpoint(
"DELETE",
&mut self.delete,
&endpoint,
filter,
MethodFilter::DELETE,
&mut self.allow_header,
&["DELETE"],
);
self
}
chained_service_fn!(delete_service, DELETE);
chained_service_fn!(get_service, GET);
chained_service_fn!(head_service, HEAD);
chained_service_fn!(options_service, OPTIONS);
chained_service_fn!(patch_service, PATCH);
chained_service_fn!(post_service, POST);
chained_service_fn!(put_service, PUT);
chained_service_fn!(trace_service, TRACE);
#[doc = include_str!("../docs/method_routing/fallback.md")]
pub fn fallback_service(mut self, svc: T) -> Self
where
T: Service, Error = E> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
// Replaces whatever fallback was set before (including the default 405 one).
self.fallback = Fallback::Service(Route::new(svc));
self
}
#[doc = include_str!("../docs/method_routing/layer.md")]
pub fn layer(self, layer: L) -> MethodRouter
where
L: Layer> + Clone + Send + 'static,
L::Service: Service> + Clone + Send + 'static,
>>::Response: IntoResponse + 'static,
>>::Error: Into + 'static,
>>::Future: Send + 'static,
E: 'static,
S: 'static,
NewReqBody: HttpBody + 'static,
NewError: 'static,
{
// Wrap every method slot AND the fallback, so the layer also runs for
// requests whose method has no route (unlike `route_layer`).
let layer_fn = move |route: Route| route.layer(layer.clone());
MethodRouter {
get: self.get.map(layer_fn.clone()),
head: self.head.map(layer_fn.clone()),
delete: self.delete.map(layer_fn.clone()),
options: self.options.map(layer_fn.clone()),
patch: self.patch.map(layer_fn.clone()),
post: self.post.map(layer_fn.clone()),
put: self.put.map(layer_fn.clone()),
trace: self.trace.map(layer_fn.clone()),
fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header,
}
}
#[doc = include_str!("../docs/method_routing/route_layer.md")]
#[track_caller]
pub fn route_layer(mut self, layer: L) -> MethodRouter
where
L: Layer> + Clone + Send + 'static,
L::Service: Service, Error = E> + Clone + Send + 'static,
>>::Response: IntoResponse + 'static,
>>::Future: Send + 'static,
E: 'static,
S: 'static,
{
// With no method routes there is nothing to wrap, so the call is almost
// certainly a mistake; fail loudly instead of silently doing nothing.
if self.get.is_none()
&& self.head.is_none()
&& self.delete.is_none()
&& self.options.is_none()
&& self.patch.is_none()
&& self.post.is_none()
&& self.put.is_none()
&& self.trace.is_none()
{
panic!(
"Adding a route_layer before any routes is a no-op. \
Add the routes you want the layer to apply to first."
);
}
// The layered service's response is converted back into `Response` via
// `IntoResponse` so it fits in a `Route` again.
let layer_fn = move |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};
// Only the method slots are wrapped; the fallback is deliberately left alone.
self.get = self.get.map(layer_fn.clone());
self.head = self.head.map(layer_fn.clone());
self.delete = self.delete.map(layer_fn.clone());
self.options = self.options.map(layer_fn.clone());
self.patch = self.patch.map(layer_fn.clone());
self.post = self.post.map(layer_fn.clone());
self.put = self.put.map(layer_fn.clone());
self.trace = self.trace.map(layer_fn);
self
}
/// Merges `other` into `self`, method slot by method slot.
///
/// `path` is only used to produce a more specific panic message.
///
/// # Panics
///
/// Panics if both routers define the same method, or if `Fallback::merge`
/// returns `None` (which, per the `expect` message, happens when both routers
/// have a fallback).
#[track_caller]
pub(crate) fn merge_for_path(
mut self,
path: Option<&str>,
other: MethodRouter,
) -> Self {
// written using inner functions to generate less IR
#[track_caller]
fn merge_inner(
path: Option<&str>,
name: &str,
first: MethodEndpoint,
second: MethodEndpoint,
) -> MethodEndpoint {
match (first, second) {
(MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
(pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
// Both sides define this method: ambiguous, so refuse to merge.
_ => {
if let Some(path) = path {
panic!(
"Overlapping method route. Handler for `{name} {path}` already exists"
);
} else {
panic!(
"Overlapping method route. Cannot merge two method routes that both \
define `{name}`"
);
}
}
}
}
self.get = merge_inner(path, "GET", self.get, other.get);
self.head = merge_inner(path, "HEAD", self.head, other.head);
self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
self.options = merge_inner(path, "OPTIONS", self.options, other.options);
self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
self.post = merge_inner(path, "POST", self.post, other.post);
self.put = merge_inner(path, "PUT", self.put, other.put);
self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
self.fallback = self
.fallback
.merge(other.fallback)
.expect("Cannot merge two `MethodRouter`s that both have a fallback");
self.allow_header = self.allow_header.merge(other.allow_header);
self
}
#[doc = include_str!("../docs/method_routing/merge.md")]
#[track_caller]
pub fn merge(self, other: MethodRouter) -> Self {
self.merge_for_path(None, other)
}
/// Apply a [`HandleErrorLayer`].
///
/// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
pub fn handle_error(self, f: F) -> MethodRouter
where
F: Clone + Send + Sync + 'static,
HandleError, F, T>: Service, Error = Infallible>,
, F, T> as Service>>::Future: Send,
, F, T> as Service>>::Response: IntoResponse + Send,
T: 'static,
E: 'static,
B: 'static,
S: 'static,
{
self.layer(HandleErrorLayer::new(f))
}
/// Marks the router so that `call_with_state` adds no `Allow` header to
/// fallback responses. Used by `any` and `any_service`.
fn skip_allow_header(mut self) -> Self {
self.allow_header = AllowHeader::Skip;
self
}
/// Dispatches `req` to the endpoint registered for its method, or to the
/// fallback when there is none.
///
/// For `HEAD` requests the `head` slot is tried before the `get` slot, and the
/// response body is stripped. Handler endpoints are turned into a route per
/// call from a clone of the boxed handler plus `state`. Fallback responses get
/// an `Allow` header according to the accumulated `AllowHeader` state.
pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture {
// Expands to: if the request method equals `$method_variant` and `$svc`
// holds an endpoint, call it and return. An empty slot falls through to the
// next expansion.
macro_rules! call {
(
$req:expr,
$method:expr,
$method_variant:ident,
$svc:expr
) => {
if $method == Method::$method_variant {
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
MethodEndpoint::BoxedHandler(handler) => {
// `state` is only moved on a path that returns, so it is still
// available to later expansions and to the fallback.
let mut route = handler.clone().into_route(state);
return RouteFuture::from_future(route.oneshot_inner($req))
.strip_body($method == Method::HEAD);
}
}
}
};
}
let method = req.method().clone();
// written with a pattern match like this to ensure we call all routes
let Self {
get,
head,
delete,
options,
patch,
post,
put,
trace,
fallback,
allow_header,
} = self;
// Order matters: an explicit HEAD route wins over GET for HEAD requests.
call!(req, method, HEAD, head);
call!(req, method, HEAD, get);
call!(req, method, GET, get);
call!(req, method, POST, post);
call!(req, method, OPTIONS, options);
call!(req, method, PATCH, patch);
call!(req, method, PUT, put);
call!(req, method, DELETE, delete);
call!(req, method, TRACE, trace);
let future = fallback.call_with_state(req, state);
match allow_header {
// No routes registered: advertise an empty `Allow` header.
AllowHeader::None => future.allow_header(Bytes::new()),
AllowHeader::Skip => future,
AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
}
}
}
/// Records `method` in the accumulated `Allow` header state.
///
/// * `AllowHeader::None` becomes a value containing just `method`.
/// * `AllowHeader::Skip` is left untouched (the router opted out of the header).
/// * An existing value gets `,<method>` appended, unless `method` is already
///   one of its comma-separated entries.
///
/// Entries are compared as whole tokens rather than by substring search. That
/// way a method name occurring inside a longer entry is never mistaken for an
/// existing one.
fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
    match allow_header {
        AllowHeader::None => {
            *allow_header = AllowHeader::Bytes(BytesMut::from(method));
        }
        AllowHeader::Skip => {}
        AllowHeader::Bytes(allow_header) => {
            // Borrow as a plain slice: `BytesMut` has an inherent `split(&mut self)`
            // that would otherwise shadow `<[u8]>::split`.
            let current: &[u8] = &**allow_header;
            // The value is only ever built from method-name literals; keep checking
            // that invariant in debug builds.
            debug_assert!(
                std::str::from_utf8(current).is_ok(),
                "`allow_header` contained invalid utf-8. This should never happen"
            );
            let already_listed = current
                .split(|&byte| byte == b',')
                .any(|existing| existing == method.as_bytes());
            if !already_listed {
                allow_header.extend_from_slice(b",");
                allow_header.extend_from_slice(method.as_bytes());
            }
        }
    }
}
impl Clone for MethodRouter {
    fn clone(&self) -> Self {
        // Destructure so that a newly added field cannot be forgotten here.
        let Self {
            get,
            head,
            delete,
            options,
            patch,
            post,
            put,
            trace,
            fallback,
            allow_header,
        } = self;
        Self {
            get: get.clone(),
            head: head.clone(),
            delete: delete.clone(),
            options: options.clone(),
            patch: patch.clone(),
            post: post.clone(),
            put: put.clone(),
            trace: trace.clone(),
            fallback: fallback.clone(),
            allow_header: allow_header.clone(),
        }
    }
}
impl Default for MethodRouter
where
B: HttpBody + Send + 'static,
S: Clone,
{
fn default() -> Self {
Self::new()
}
}
/// What a single HTTP-method slot of a `MethodRouter` is routed to.
enum MethodEndpoint {
    /// A ready-to-call route.
    Route(Route),
    /// A type-erased handler; it becomes a `Route` once it is given the router state.
    BoxedHandler(BoxedIntoRoute),
    /// Nothing is registered for this method.
    None,
}
impl MethodEndpoint
where
    S: Clone,
{
    /// `true` when a route or handler is registered in this slot.
    fn is_some(&self) -> bool {
        !self.is_none()
    }

    /// `true` when nothing is registered in this slot.
    fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }

    /// Applies `f` to the route in this slot. A boxed handler records `f` for
    /// later; an empty slot stays empty.
    fn map(self, f: F) -> MethodEndpoint
    where
        S: 'static,
        B: 'static,
        E: 'static,
        F: FnOnce(Route) -> Route + Clone + Send + 'static,
        B2: HttpBody + 'static,
        E2: 'static,
    {
        match self {
            Self::Route(route) => MethodEndpoint::Route(f(route)),
            Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
            Self::None => MethodEndpoint::None,
        }
    }

    /// Resolves this slot against `state`. Handlers are turned into routes using
    /// a clone of `state`; routes are kept as-is; an empty slot stays empty.
    fn with_state(self, state: &S) -> MethodEndpoint {
        let resolved = match self {
            MethodEndpoint::None => return MethodEndpoint::None,
            MethodEndpoint::Route(route) => route,
            MethodEndpoint::BoxedHandler(handler) => handler.into_route(state.clone()),
        };
        MethodEndpoint::Route(resolved)
    }
}
impl Clone for MethodEndpoint {
    fn clone(&self) -> Self {
        match self {
            Self::Route(route) => Self::Route(route.clone()),
            Self::BoxedHandler(boxed) => Self::BoxedHandler(boxed.clone()),
            Self::None => Self::None,
        }
    }
}
impl fmt::Debug for MethodEndpoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Routes format themselves; the other variants print only their name
        // (a boxed handler is type-erased and has nothing useful to show).
        let label = match self {
            Self::Route(inner) => return inner.fmt(f),
            Self::None => "None",
            Self::BoxedHandler(_) => "BoxedHandler",
        };
        f.debug_tuple(label).finish()
    }
}
impl Service> for MethodRouter<(), B, E>
where
    B: HttpBody + Send + 'static,
{
    type Future = RouteFuture;
    type Error = E;
    type Response = Response;

    /// Dispatches the request; a state-less router is driven with `()` as its state.
    #[inline]
    fn call(&mut self, req: Request) -> Self::Future {
        self.call_with_state(req, ())
    }

    /// Always ready: this impl does no readiness tracking of its own.
    #[inline]
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> {
        Poll::Ready(Ok(()))
    }
}
impl Handler<(), S, B> for MethodRouter
where
    S: Clone + 'static,
    B: HttpBody + Send + 'static,
{
    type Future = InfallibleRouteFuture;

    /// Lets a `MethodRouter` be used wherever a handler is expected; the
    /// dispatch result is wrapped in an `InfallibleRouteFuture`.
    fn call(mut self, req: Request, state: S) -> Self::Future {
        let dispatched = self.call_with_state(req, state);
        InfallibleRouteFuture::new(dispatched)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
body::Body, error_handling::HandleErrorLayer, extract::State,
handler::HandlerWithoutStateExt,
};
use axum_core::response::IntoResponse;
use http::{header::ALLOW, HeaderMap};
use std::time::Duration;
use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer};
#[crate::test]
async fn method_not_allowed_by_default() {
let mut svc = MethodRouter::new();
let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
assert!(body.is_empty());
}
#[crate::test]
async fn get_service_fn() {
async fn handle(_req: Request) -> Result, Infallible> {
Ok(Response::new(Body::from("ok")))
}
let mut svc = get_service(service_fn(handle));
let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert_eq!(body, "ok");
}
#[crate::test]
async fn get_handler() {
let mut svc = MethodRouter::new().get(ok);
let (status, _, body) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert_eq!(body, "ok");
}
#[crate::test]
async fn get_accepts_head() {
let mut svc = MethodRouter::new().get(ok);
let (status, _, body) = call(Method::HEAD, &mut svc).await;
assert_eq!(status, StatusCode::OK);
assert!(body.is_empty());
}
#[crate::test]
async fn head_takes_precedence_over_get() {
let mut svc = MethodRouter::new().head(created).get(ok);
let (status, _, body) = call(Method::HEAD, &mut svc).await;
assert_eq!(status, StatusCode::CREATED);
assert!(body.is_empty());
}
#[crate::test]
async fn merge() {
let mut svc = get(ok).merge(post(ok));
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::OK);
let (status, _, _) = call(Method::POST, &mut svc).await;
assert_eq!(status, StatusCode::OK);
}
#[crate::test]
async fn layer() {
let mut svc = MethodRouter::new()
.get(|| async { std::future::pending::<()>().await })
.layer(ValidateRequestHeaderLayer::bearer("password"));
// method with route
let (status, _, _) = call(Method::GET, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED);
// method without route
let (status, _, _) = call(Method::DELETE, &mut svc).await;
assert_eq!(status, StatusCode::UNAUTHORIZED);
}
#[crate::test]
async fn route_layer() {
    // `route_layer` wraps only the registered routes. The routed method hits
    // the bearer check, while an unrouted method falls through to 405.
    let mut router = MethodRouter::new()
        .get(|| async { std::future::pending::<()>().await })
        .route_layer(ValidateRequestHeaderLayer::bearer("password"));

    let cases = [
        (Method::GET, StatusCode::UNAUTHORIZED),
        (Method::DELETE, StatusCode::METHOD_NOT_ALLOWED),
    ];
    for (method, expected) in cases {
        let (code, _, _) = call(method, &mut router).await;
        assert_eq!(code, expected);
    }
}
#[allow(dead_code)]
fn buiding_complex_router() {
    // Compile-only check (the fn is `dead_code`): chain every kind of
    // `MethodRouter` builder call and make sure the result can still be served.
    // The chain covers handlers, services, route and outer layers, merge, and fallback.
    let timeout = Duration::from_secs(10);
    let app = crate::Router::new().route(
        "/",
        // use all the things 💣️
        get(ok)
            .post(ok)
            .route_layer(ValidateRequestHeaderLayer::bearer("password"))
            .merge(delete_service(ServeDir::new(".")))
            .fallback(|| async { StatusCode::NOT_FOUND })
            .put(ok)
            .layer(
                ServiceBuilder::new()
                    .layer(HandleErrorLayer::new(|_| async {
                        StatusCode::REQUEST_TIMEOUT
                    }))
                    .layer(TimeoutLayer::new(timeout)),
            ),
    );

    let addr = "0.0.0.0:0".parse().unwrap();
    crate::Server::bind(&addr).serve(app.into_make_service());
}
#[crate::test]
async fn sets_allow_header() {
    // A 405 response lists the registered methods in the `Allow` header.
    let mut router = MethodRouter::new().put(ok).patch(ok);
    let (code, headers, _) = call(Method::GET, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "PUT,PATCH");
}
#[crate::test]
async fn sets_allow_header_get_head() {
    // With explicit `GET` and `HEAD` routes, `HEAD` is listed exactly once.
    let mut router = MethodRouter::new().get(ok).head(ok);
    let (code, headers, _) = call(Method::PUT, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[crate::test]
async fn empty_allow_header_by_default() {
    // With no routes, the `Allow` header is present but empty.
    let mut router = MethodRouter::new();
    let (code, headers, _) = call(Method::PATCH, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "");
}
#[crate::test]
async fn allow_header_when_merging() {
    // Merging combines the allowed-method lists: left side first, then right.
    let mut router = put(ok).patch(ok).merge(get(ok).head(ok));
    let (code, headers, _) = call(Method::DELETE, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
}
#[crate::test]
async fn allow_header_any() {
    // `any` accepts every method, so a successful response carries no `Allow` header.
    let mut router = any(ok);
    let (code, headers, _) = call(Method::GET, &mut router).await;
    assert_eq!(code, StatusCode::OK);
    assert!(headers.get(ALLOW).is_none());
}
#[crate::test]
async fn allow_header_with_fallback() {
    // A fallback that itself returns 405 still gets the router's `Allow` header.
    let fallback = || async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") };
    let mut router = MethodRouter::new().get(ok).fallback(fallback);
    let (code, headers, _) = call(Method::DELETE, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[crate::test]
async fn allow_header_with_fallback_that_sets_allow() {
    // Fallback that accepts `POST` itself and otherwise answers 405 with its
    // own `Allow` header. The router must not overwrite that header.
    async fn fallback(method: Method) -> Response {
        if method != Method::POST {
            return (
                StatusCode::METHOD_NOT_ALLOWED,
                [(ALLOW, "GET,POST")],
                "Method not allowed",
            )
                .into_response();
        }
        "OK".into_response()
    }

    let mut router = MethodRouter::new().get(ok).fallback(fallback);

    // `GET` is routed directly; `POST` is handled by the fallback.
    for method in [Method::GET, Method::POST] {
        let (code, _, _) = call(method, &mut router).await;
        assert_eq!(code, StatusCode::OK);
    }

    let (code, headers, _) = call(Method::DELETE, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "GET,POST");
}
#[crate::test]
async fn allow_header_noop_middleware() {
    // Wrapping with a no-op layer must not lose the `Allow` header on 405.
    let identity = tower::layer::util::Identity::new();
    let mut router = MethodRouter::new().get(ok).layer(identity);
    let (code, headers, _) = call(Method::DELETE, &mut router).await;
    assert_eq!(code, StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(headers[ALLOW], "GET,HEAD");
}
#[crate::test]
#[should_panic(
    expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
)]
async fn handler_overlaps() {
    // Registering a second `GET` handler on the same router must panic.
    let _router: MethodRouter<()> = get(ok).get(ok);
}
#[crate::test]
#[should_panic(
    expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
)]
async fn service_overlaps() {
    // Same as `handler_overlaps`, but via the `*_service` constructors.
    let first = ok.into_service();
    let second = ok.into_service();
    let _router: MethodRouter<()> = post_service(first).post_service(second);
}
#[crate::test]
async fn get_head_does_not_overlap() {
    // The implicit `HEAD` handling of `GET` must not conflict with an explicit `HEAD`.
    let _router: MethodRouter<()> = get(ok).head(ok);
}
#[crate::test]
async fn head_get_does_not_overlap() {
    // Same as `get_head_does_not_overlap`, with the registration order reversed.
    let _router: MethodRouter<()> = head(ok).get(ok);
}
#[crate::test]
async fn accessing_state() {
    // A `GET` handler can extract the state supplied through `with_state`.
    let echo_state = |State(state): State<&'static str>| async move { state };
    let mut router = MethodRouter::new().get(echo_state).with_state("state");
    let (code, _, body) = call(Method::GET, &mut router).await;
    assert_eq!(code, StatusCode::OK);
    assert_eq!(body, "state");
}
#[crate::test]
async fn fallback_accessing_state() {
    // The fallback can extract the state supplied through `with_state`.
    let echo_state = |State(state): State<&'static str>| async move { state };
    let mut router = MethodRouter::new().fallback(echo_state).with_state("state");
    let (code, _, body) = call(Method::GET, &mut router).await;
    assert_eq!(code, StatusCode::OK);
    assert_eq!(body, "state");
}
#[crate::test]
async fn merge_accessing_state() {
    // Both merged halves extract the state. It must reach each of them after
    // `merge` followed by `with_state`.
    let one = get(|State(state): State<&'static str>| async move { state });
    let two = post(|State(state): State<&'static str>| async move { state });
    let mut svc = one.merge(two).with_state("state");

    let (status, _, text) = call(Method::GET, &mut svc).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(text, "state");

    // Bind the POST body as well, so that the merged-in `POST` handler's state
    // extraction is actually checked (not just the `GET` body a second time).
    let (status, _, text) = call(Method::POST, &mut svc).await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(text, "state");
}
async fn call(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
where
S: Service, Error = Infallible>,
S::Response: IntoResponse,
{
let request = Request::builder()
.uri("/")
.method(method)
.body(Body::empty())
.unwrap();
let response = svc
.ready()
.await
.unwrap()
.call(request)
.await
.unwrap()
.into_response();
let (parts, body) = response.into_parts();
let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
(parts.status, parts.headers, body)
}
/// Test handler: responds `200 OK` with the body `"ok"`.
async fn ok() -> (StatusCode, &'static str) {
    let body = "ok";
    (StatusCode::OK, body)
}
/// Test handler: responds `201 Created` with the body `"created"`.
async fn created() -> (StatusCode, &'static str) {
    let body = "created";
    (StatusCode::CREATED, body)
}
}