1 //! Route to services and handlers based on HTTP methods.
2 
3 use super::{future::InfallibleRouteFuture, IntoMakeService};
4 #[cfg(feature = "tokio")]
5 use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6 use crate::{
7     body::{Body, Bytes, HttpBody},
8     boxed::BoxedIntoRoute,
9     error_handling::{HandleError, HandleErrorLayer},
10     handler::Handler,
11     http::{Method, Request, StatusCode},
12     response::Response,
13     routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14 };
15 use axum_core::response::IntoResponse;
16 use bytes::BytesMut;
17 use std::{
18     convert::Infallible,
19     fmt,
20     task::{Context, Poll},
21 };
22 use tower::{service_fn, util::MapResponseLayer};
23 use tower_layer::Layer;
24 use tower_service::Service;
25 
// Generates the free-standing `*_service` constructors (`get_service`,
// `post_service`, ...), each of which forwards to `on_service` with the
// matching `MethodFilter` constant.
//
// Arm order matters: the literal `GET` arm has to be tried before the generic
// `$verb:ident` arm. Both of them only attach documentation and then recurse
// into the last arm, which emits the actual function.
macro_rules! top_level_service_fn {
    // `GET`: long-form documentation with a worked example.
    (
        $fn_name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     http::Request,
            ///     Router,
            ///     routing::get_service,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            /// use hyper::Body;
            ///
            /// let service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that a `get` route is also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add an explicit `HEAD`
            /// route afterwards if `HEAD` needs its own handling.
            $fn_name,
            GET
        );
    };

    // Every other method: a one-line summary that points at `get_service`.
    (
        $fn_name:ident, $verb:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($verb) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $fn_name,
            $verb
        );
    };

    // Emits the function itself, carrying over the attributes collected above.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        pub fn $fn_name<T, S, B>(svc: T) -> MethodRouter<S, B, T::Error>
        where
            T: Service<Request<B>> + Clone + Send + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            B: HttpBody + Send + 'static,
            S: Clone,
        {
            let filter = MethodFilter::$verb;
            on_service(filter, svc)
        }
    };
}
93 
// Generates the free-standing handler constructors (`get`, `post`, ...), each
// of which forwards to `on` with the matching `MethodFilter` constant.
//
// Arm order matters: the literal `GET` arm has to be tried before the generic
// `$verb:ident` arm. Both only attach documentation and then recurse into the
// last arm, which emits the actual function.
macro_rules! top_level_handler_fn {
    // `GET`: long-form documentation with a worked example.
    (
        $fn_name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that a `get` route is also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add an explicit `HEAD`
            /// route afterwards if `HEAD` needs its own handling.
            $fn_name,
            GET
        );
    };

    // Every other method: a one-line summary that points at `get`.
    (
        $fn_name:ident, $verb:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($verb) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $fn_name,
            $verb
        );
    };

    // Emits the function itself, carrying over the attributes collected above.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        pub fn $fn_name<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible>
        where
            H: Handler<T, S, B>,
            B: HttpBody + Send + 'static,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            let filter = MethodFilter::$verb;
            on(filter, handler)
        }
    };
}
154 
// Generates the chaining `*_service` methods on `MethodRouter`
// (`.get_service(..)`, `.post_service(..)`, ...), each of which forwards to
// `MethodRouter::on_service` with the matching `MethodFilter` constant.
//
// Arm order matters: the literal `GET` arm has to be tried before the generic
// `$verb:ident` arm. Both only attach documentation and then recurse into the
// last arm, which emits the actual method.
macro_rules! chained_service_fn {
    // `GET`: long-form documentation with a worked example.
    (
        $fn_name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     http::Request,
            ///     Router,
            ///     routing::post_service,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            /// use hyper::Body;
            ///
            /// let service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request<Body>| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that a `get` route is also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add an explicit `HEAD`
            /// route afterwards if `HEAD` needs its own handling.
            $fn_name,
            GET
        );
    };

    // Every other method: a one-line summary that points at `get_service`.
    (
        $fn_name:ident, $verb:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($verb),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $fn_name,
            $verb
        );
    };

    // Emits the method itself. `#[track_caller]` keeps the location of an
    // overlapping-route panic at the user's call site.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $fn_name<T>(self, svc: T) -> Self
        where
            T: Service<Request<B>, Error = E>
                + Clone
                + Send
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            let filter = MethodFilter::$verb;
            self.on_service(filter, svc)
        }
    };
}
229 
// Generates the chaining handler methods on `MethodRouter` (`.get(..)`,
// `.post(..)`, ...), each of which forwards to `MethodRouter::on` with the
// matching `MethodFilter` constant.
//
// Arm order matters: the literal `GET` arm has to be tried before the generic
// `$verb:ident` arm. Both only attach documentation and then recurse into the
// last arm, which emits the actual method.
macro_rules! chained_handler_fn {
    // `GET`: long-form documentation with a worked example.
    (
        $fn_name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # async {
            /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
            /// # };
            /// ```
            ///
            /// Note that a `get` route is also called for `HEAD` requests, in which
            /// case the response body is removed. Make sure to add an explicit `HEAD`
            /// route afterwards if `HEAD` needs its own handling.
            $fn_name,
            GET
        );
    };

    // Every other method: a one-line summary that points at `get`.
    (
        $fn_name:ident, $verb:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($verb),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $fn_name,
            $verb
        );
    };

    // Emits the method itself. `#[track_caller]` keeps the location of an
    // overlapping-route panic at the user's call site.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $fn_name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S, B>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            let filter = MethodFilter::$verb;
            self.on(filter, handler)
        }
    };
}
290 
291 top_level_service_fn!(delete_service, DELETE);
292 top_level_service_fn!(get_service, GET);
293 top_level_service_fn!(head_service, HEAD);
294 top_level_service_fn!(options_service, OPTIONS);
295 top_level_service_fn!(patch_service, PATCH);
296 top_level_service_fn!(post_service, POST);
297 top_level_service_fn!(put_service, PUT);
298 top_level_service_fn!(trace_service, TRACE);
299 
300 /// Route requests with the given method to the service.
301 ///
302 /// # Example
303 ///
304 /// ```rust
305 /// use axum::{
306 ///     http::Request,
307 ///     routing::on,
308 ///     Router,
309 ///     routing::{MethodFilter, on_service},
310 /// };
311 /// use http::Response;
312 /// use std::convert::Infallible;
313 /// use hyper::Body;
314 ///
315 /// let service = tower::service_fn(|request: Request<Body>| async {
316 ///     Ok::<_, Infallible>(Response::new(Body::empty()))
317 /// });
318 ///
319 /// // Requests to `POST /` will go to `service`.
320 /// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
321 /// # async {
322 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
323 /// # };
324 /// ```
on_service<T, S, B>(filter: MethodFilter, svc: T) -> MethodRouter<S, B, T::Error> where T: Service<Request<B>> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone,325 pub fn on_service<T, S, B>(filter: MethodFilter, svc: T) -> MethodRouter<S, B, T::Error>
326 where
327     T: Service<Request<B>> + Clone + Send + 'static,
328     T::Response: IntoResponse + 'static,
329     T::Future: Send + 'static,
330     B: HttpBody + Send + 'static,
331     S: Clone,
332 {
333     MethodRouter::new().on_service(filter, svc)
334 }
335 
336 /// Route requests to the given service regardless of its method.
337 ///
338 /// # Example
339 ///
340 /// ```rust
341 /// use axum::{
342 ///     http::Request,
343 ///     Router,
344 ///     routing::any_service,
345 /// };
346 /// use http::Response;
347 /// use std::convert::Infallible;
348 /// use hyper::Body;
349 ///
350 /// let service = tower::service_fn(|request: Request<Body>| async {
351 ///     Ok::<_, Infallible>(Response::new(Body::empty()))
352 /// });
353 ///
354 /// // All requests to `/` will go to `service`.
355 /// let app = Router::new().route("/", any_service(service));
356 /// # async {
357 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
358 /// # };
359 /// ```
360 ///
361 /// Additional methods can still be chained:
362 ///
363 /// ```rust
364 /// use axum::{
365 ///     http::Request,
366 ///     Router,
367 ///     routing::any_service,
368 /// };
369 /// use http::Response;
370 /// use std::convert::Infallible;
371 /// use hyper::Body;
372 ///
373 /// let service = tower::service_fn(|request: Request<Body>| async {
374 ///     # Ok::<_, Infallible>(Response::new(Body::empty()))
375 ///     // ...
376 /// });
377 ///
378 /// let other_service = tower::service_fn(|request: Request<Body>| async {
379 ///     # Ok::<_, Infallible>(Response::new(Body::empty()))
380 ///     // ...
381 /// });
382 ///
383 /// // `POST /` goes to `other_service`. All other requests go to `service`
384 /// let app = Router::new().route("/", any_service(service).post_service(other_service));
385 /// # async {
386 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
387 /// # };
388 /// ```
any_service<T, S, B>(svc: T) -> MethodRouter<S, B, T::Error> where T: Service<Request<B>> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, B: HttpBody + Send + 'static, S: Clone,389 pub fn any_service<T, S, B>(svc: T) -> MethodRouter<S, B, T::Error>
390 where
391     T: Service<Request<B>> + Clone + Send + 'static,
392     T::Response: IntoResponse + 'static,
393     T::Future: Send + 'static,
394     B: HttpBody + Send + 'static,
395     S: Clone,
396 {
397     MethodRouter::new()
398         .fallback_service(svc)
399         .skip_allow_header()
400 }
401 
402 top_level_handler_fn!(delete, DELETE);
403 top_level_handler_fn!(get, GET);
404 top_level_handler_fn!(head, HEAD);
405 top_level_handler_fn!(options, OPTIONS);
406 top_level_handler_fn!(patch, PATCH);
407 top_level_handler_fn!(post, POST);
408 top_level_handler_fn!(put, PUT);
409 top_level_handler_fn!(trace, TRACE);
410 
411 /// Route requests with the given method to the handler.
412 ///
413 /// # Example
414 ///
415 /// ```rust
416 /// use axum::{
417 ///     routing::on,
418 ///     Router,
419 ///     routing::MethodFilter,
420 /// };
421 ///
422 /// async fn handler() {}
423 ///
424 /// // Requests to `POST /` will go to `handler`.
425 /// let app = Router::new().route("/", on(MethodFilter::POST, handler));
426 /// # async {
427 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
428 /// # };
429 /// ```
on<H, T, S, B>(filter: MethodFilter, handler: H) -> MethodRouter<S, B, Infallible> where H: Handler<T, S, B>, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static,430 pub fn on<H, T, S, B>(filter: MethodFilter, handler: H) -> MethodRouter<S, B, Infallible>
431 where
432     H: Handler<T, S, B>,
433     B: HttpBody + Send + 'static,
434     T: 'static,
435     S: Clone + Send + Sync + 'static,
436 {
437     MethodRouter::new().on(filter, handler)
438 }
439 
440 /// Route requests with the given handler regardless of the method.
441 ///
442 /// # Example
443 ///
444 /// ```rust
445 /// use axum::{
446 ///     routing::any,
447 ///     Router,
448 /// };
449 ///
450 /// async fn handler() {}
451 ///
452 /// // All requests to `/` will go to `handler`.
453 /// let app = Router::new().route("/", any(handler));
454 /// # async {
455 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
456 /// # };
457 /// ```
458 ///
459 /// Additional methods can still be chained:
460 ///
461 /// ```rust
462 /// use axum::{
463 ///     routing::any,
464 ///     Router,
465 /// };
466 ///
467 /// async fn handler() {}
468 ///
469 /// async fn other_handler() {}
470 ///
471 /// // `POST /` goes to `other_handler`. All other requests go to `handler`
472 /// let app = Router::new().route("/", any(handler).post(other_handler));
473 /// # async {
474 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
475 /// # };
476 /// ```
any<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible> where H: Handler<T, S, B>, B: HttpBody + Send + 'static, T: 'static, S: Clone + Send + Sync + 'static,477 pub fn any<H, T, S, B>(handler: H) -> MethodRouter<S, B, Infallible>
478 where
479     H: Handler<T, S, B>,
480     B: HttpBody + Send + 'static,
481     T: 'static,
482     S: Clone + Send + Sync + 'static,
483 {
484     MethodRouter::new().fallback(handler).skip_allow_header()
485 }
486 
487 /// A [`Service`] that accepts requests based on a [`MethodFilter`] and
488 /// allows chaining additional handlers and services.
489 ///
490 /// # When does `MethodRouter` implement [`Service`]?
491 ///
492 /// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires.
493 ///
494 /// ```
495 /// use tower::Service;
496 /// use axum::{routing::get, extract::State, body::Body, http::Request};
497 ///
498 /// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
499 /// let method_router = get(|| async {});
500 /// // and thus it implements `Service`
501 /// assert_service(method_router);
502 ///
503 /// // this requires a `String` and doesn't implement `Service`
504 /// let method_router = get(|_: State<String>| async {});
505 /// // until you provide the `String` with `.with_state(...)`
506 /// let method_router_with_state = method_router.with_state(String::new());
507 /// // and then it implements `Service`
508 /// assert_service(method_router_with_state);
509 ///
510 /// // helper to check that a value implements `Service`
511 /// fn assert_service<S>(service: S)
512 /// where
513 ///     S: Service<Request<Body>>,
514 /// {}
515 /// ```
516 #[must_use]
517 pub struct MethodRouter<S = (), B = Body, E = Infallible> {
518     get: MethodEndpoint<S, B, E>,
519     head: MethodEndpoint<S, B, E>,
520     delete: MethodEndpoint<S, B, E>,
521     options: MethodEndpoint<S, B, E>,
522     patch: MethodEndpoint<S, B, E>,
523     post: MethodEndpoint<S, B, E>,
524     put: MethodEndpoint<S, B, E>,
525     trace: MethodEndpoint<S, B, E>,
526     fallback: Fallback<S, B, E>,
527     allow_header: AllowHeader,
528 }
529 
530 #[derive(Clone, Debug)]
531 enum AllowHeader {
532     /// No `Allow` header value has been built-up yet. This is the default state
533     None,
534     /// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
535     Skip,
536     /// The current value of the `Allow` header.
537     Bytes(BytesMut),
538 }
539 
540 impl AllowHeader {
merge(self, other: Self) -> Self541     fn merge(self, other: Self) -> Self {
542         match (self, other) {
543             (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
544             (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
545             (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
546             (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
547             (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
548                 a.extend_from_slice(b",");
549                 a.extend_from_slice(&b);
550                 AllowHeader::Bytes(a)
551             }
552         }
553     }
554 }
555 
556 impl<S, B, E> fmt::Debug for MethodRouter<S, B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result557     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
558         f.debug_struct("MethodRouter")
559             .field("get", &self.get)
560             .field("head", &self.head)
561             .field("delete", &self.delete)
562             .field("options", &self.options)
563             .field("patch", &self.patch)
564             .field("post", &self.post)
565             .field("put", &self.put)
566             .field("trace", &self.trace)
567             .field("fallback", &self.fallback)
568             .field("allow_header", &self.allow_header)
569             .finish()
570     }
571 }
572 
573 impl<S, B> MethodRouter<S, B, Infallible>
574 where
575     B: HttpBody + Send + 'static,
576     S: Clone,
577 {
578     /// Chain an additional handler that will accept requests matching the given
579     /// `MethodFilter`.
580     ///
581     /// # Example
582     ///
583     /// ```rust
584     /// use axum::{
585     ///     routing::get,
586     ///     Router,
587     ///     routing::MethodFilter
588     /// };
589     ///
590     /// async fn handler() {}
591     ///
592     /// async fn other_handler() {}
593     ///
594     /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
595     /// // `other_handler`
596     /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
597     /// # async {
598     /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
599     /// # };
600     /// ```
601     #[track_caller]
on<H, T>(self, filter: MethodFilter, handler: H) -> Self where H: Handler<T, S, B>, T: 'static, S: Send + Sync + 'static,602     pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
603     where
604         H: Handler<T, S, B>,
605         T: 'static,
606         S: Send + Sync + 'static,
607     {
608         self.on_endpoint(
609             filter,
610             MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
611         )
612     }
613 
614     chained_handler_fn!(delete, DELETE);
615     chained_handler_fn!(get, GET);
616     chained_handler_fn!(head, HEAD);
617     chained_handler_fn!(options, OPTIONS);
618     chained_handler_fn!(patch, PATCH);
619     chained_handler_fn!(post, POST);
620     chained_handler_fn!(put, PUT);
621     chained_handler_fn!(trace, TRACE);
622 
623     /// Add a fallback [`Handler`] to the router.
fallback<H, T>(mut self, handler: H) -> Self where H: Handler<T, S, B>, T: 'static, S: Send + Sync + 'static,624     pub fn fallback<H, T>(mut self, handler: H) -> Self
625     where
626         H: Handler<T, S, B>,
627         T: 'static,
628         S: Send + Sync + 'static,
629     {
630         self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
631         self
632     }
633 }
634 
635 impl<B> MethodRouter<(), B, Infallible>
636 where
637     B: HttpBody + Send + 'static,
638 {
639     /// Convert the handler into a [`MakeService`].
640     ///
641     /// This allows you to serve a single handler if you don't need any routing:
642     ///
643     /// ```rust
644     /// use axum::{
645     ///     Server,
646     ///     handler::Handler,
647     ///     http::{Uri, Method},
648     ///     response::IntoResponse,
649     ///     routing::get,
650     /// };
651     /// use std::net::SocketAddr;
652     ///
653     /// async fn handler(method: Method, uri: Uri, body: String) -> String {
654     ///     format!("received `{} {}` with body `{:?}`", method, uri, body)
655     /// }
656     ///
657     /// let router = get(handler).post(handler);
658     ///
659     /// # async {
660     /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
661     ///     .serve(router.into_make_service())
662     ///     .await?;
663     /// # Ok::<_, hyper::Error>(())
664     /// # };
665     /// ```
666     ///
667     /// [`MakeService`]: tower::make::MakeService
into_make_service(self) -> IntoMakeService<Self>668     pub fn into_make_service(self) -> IntoMakeService<Self> {
669         IntoMakeService::new(self.with_state(()))
670     }
671 
672     /// Convert the router into a [`MakeService`] which stores information
673     /// about the incoming connection.
674     ///
675     /// See [`Router::into_make_service_with_connect_info`] for more details.
676     ///
677     /// ```rust
678     /// use axum::{
679     ///     Server,
680     ///     handler::Handler,
681     ///     response::IntoResponse,
682     ///     extract::ConnectInfo,
683     ///     routing::get,
684     /// };
685     /// use std::net::SocketAddr;
686     ///
687     /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
688     ///     format!("Hello {}", addr)
689     /// }
690     ///
691     /// let router = get(handler).post(handler);
692     ///
693     /// # async {
694     /// Server::bind(&SocketAddr::from(([127, 0, 0, 1], 3000)))
695     ///     .serve(router.into_make_service_with_connect_info::<SocketAddr>())
696     ///     .await?;
697     /// # Ok::<_, hyper::Error>(())
698     /// # };
699     /// ```
700     ///
701     /// [`MakeService`]: tower::make::MakeService
702     /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
703     #[cfg(feature = "tokio")]
into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>704     pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
705         IntoMakeServiceWithConnectInfo::new(self.with_state(()))
706     }
707 }
708 
709 impl<S, B, E> MethodRouter<S, B, E>
710 where
711     B: HttpBody + Send + 'static,
712     S: Clone,
713 {
714     /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
715     /// requests.
new() -> Self716     pub fn new() -> Self {
717         let fallback = Route::new(service_fn(|_: Request<B>| async {
718             Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
719         }));
720 
721         Self {
722             get: MethodEndpoint::None,
723             head: MethodEndpoint::None,
724             delete: MethodEndpoint::None,
725             options: MethodEndpoint::None,
726             patch: MethodEndpoint::None,
727             post: MethodEndpoint::None,
728             put: MethodEndpoint::None,
729             trace: MethodEndpoint::None,
730             allow_header: AllowHeader::None,
731             fallback: Fallback::Default(fallback),
732         }
733     }
734 
735     /// Provide the state for the router.
with_state<S2>(self, state: S) -> MethodRouter<S2, B, E>736     pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, B, E> {
737         MethodRouter {
738             get: self.get.with_state(&state),
739             head: self.head.with_state(&state),
740             delete: self.delete.with_state(&state),
741             options: self.options.with_state(&state),
742             patch: self.patch.with_state(&state),
743             post: self.post.with_state(&state),
744             put: self.put.with_state(&state),
745             trace: self.trace.with_state(&state),
746             allow_header: self.allow_header,
747             fallback: self.fallback.with_state(state),
748         }
749     }
750 
751     /// Chain an additional service that will accept requests matching the given
752     /// `MethodFilter`.
753     ///
754     /// # Example
755     ///
756     /// ```rust
757     /// use axum::{
758     ///     http::Request,
759     ///     Router,
760     ///     routing::{MethodFilter, on_service},
761     /// };
762     /// use http::Response;
763     /// use std::convert::Infallible;
764     /// use hyper::Body;
765     ///
766     /// let service = tower::service_fn(|request: Request<Body>| async {
767     ///     Ok::<_, Infallible>(Response::new(Body::empty()))
768     /// });
769     ///
770     /// // Requests to `DELETE /` will go to `service`
771     /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
772     /// # async {
773     /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
774     /// # };
775     /// ```
776     #[track_caller]
on_service<T>(self, filter: MethodFilter, svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static,777     pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
778     where
779         T: Service<Request<B>, Error = E> + Clone + Send + 'static,
780         T::Response: IntoResponse + 'static,
781         T::Future: Send + 'static,
782     {
783         self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
784     }
785 
786     #[track_caller]
on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, B, E>) -> Self787     fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, B, E>) -> Self {
788         // written as a separate function to generate less IR
789         #[track_caller]
790         fn set_endpoint<S, B, E>(
791             method_name: &str,
792             out: &mut MethodEndpoint<S, B, E>,
793             endpoint: &MethodEndpoint<S, B, E>,
794             endpoint_filter: MethodFilter,
795             filter: MethodFilter,
796             allow_header: &mut AllowHeader,
797             methods: &[&'static str],
798         ) where
799             MethodEndpoint<S, B, E>: Clone,
800             S: Clone,
801         {
802             if endpoint_filter.contains(filter) {
803                 if out.is_some() {
804                     panic!(
805                         "Overlapping method route. Cannot add two method routes that both handle \
806                          `{method_name}`",
807                     )
808                 }
809                 *out = endpoint.clone();
810                 for method in methods {
811                     append_allow_header(allow_header, method);
812                 }
813             }
814         }
815 
816         set_endpoint(
817             "GET",
818             &mut self.get,
819             &endpoint,
820             filter,
821             MethodFilter::GET,
822             &mut self.allow_header,
823             &["GET", "HEAD"],
824         );
825 
826         set_endpoint(
827             "HEAD",
828             &mut self.head,
829             &endpoint,
830             filter,
831             MethodFilter::HEAD,
832             &mut self.allow_header,
833             &["HEAD"],
834         );
835 
836         set_endpoint(
837             "TRACE",
838             &mut self.trace,
839             &endpoint,
840             filter,
841             MethodFilter::TRACE,
842             &mut self.allow_header,
843             &["TRACE"],
844         );
845 
846         set_endpoint(
847             "PUT",
848             &mut self.put,
849             &endpoint,
850             filter,
851             MethodFilter::PUT,
852             &mut self.allow_header,
853             &["PUT"],
854         );
855 
856         set_endpoint(
857             "POST",
858             &mut self.post,
859             &endpoint,
860             filter,
861             MethodFilter::POST,
862             &mut self.allow_header,
863             &["POST"],
864         );
865 
866         set_endpoint(
867             "PATCH",
868             &mut self.patch,
869             &endpoint,
870             filter,
871             MethodFilter::PATCH,
872             &mut self.allow_header,
873             &["PATCH"],
874         );
875 
876         set_endpoint(
877             "OPTIONS",
878             &mut self.options,
879             &endpoint,
880             filter,
881             MethodFilter::OPTIONS,
882             &mut self.allow_header,
883             &["OPTIONS"],
884         );
885 
886         set_endpoint(
887             "DELETE",
888             &mut self.delete,
889             &endpoint,
890             filter,
891             MethodFilter::DELETE,
892             &mut self.allow_header,
893             &["DELETE"],
894         );
895 
896         self
897     }
898 
899     chained_service_fn!(delete_service, DELETE);
900     chained_service_fn!(get_service, GET);
901     chained_service_fn!(head_service, HEAD);
902     chained_service_fn!(options_service, OPTIONS);
903     chained_service_fn!(patch_service, PATCH);
904     chained_service_fn!(post_service, POST);
905     chained_service_fn!(put_service, PUT);
906     chained_service_fn!(trace_service, TRACE);
907 
908     #[doc = include_str!("../docs/method_routing/fallback.md")]
fallback_service<T>(mut self, svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static,909     pub fn fallback_service<T>(mut self, svc: T) -> Self
910     where
911         T: Service<Request<B>, Error = E> + Clone + Send + 'static,
912         T::Response: IntoResponse + 'static,
913         T::Future: Send + 'static,
914     {
915         self.fallback = Fallback::Service(Route::new(svc));
916         self
917     }
918 
919     #[doc = include_str!("../docs/method_routing/layer.md")]
layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError> where L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, E: 'static, S: 'static, NewReqBody: HttpBody + 'static, NewError: 'static,920     pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError>
921     where
922         L: Layer<Route<B, E>> + Clone + Send + 'static,
923         L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
924         <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
925         <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
926         <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
927         E: 'static,
928         S: 'static,
929         NewReqBody: HttpBody + 'static,
930         NewError: 'static,
931     {
932         let layer_fn = move |route: Route<B, E>| route.layer(layer.clone());
933 
934         MethodRouter {
935             get: self.get.map(layer_fn.clone()),
936             head: self.head.map(layer_fn.clone()),
937             delete: self.delete.map(layer_fn.clone()),
938             options: self.options.map(layer_fn.clone()),
939             patch: self.patch.map(layer_fn.clone()),
940             post: self.post.map(layer_fn.clone()),
941             put: self.put.map(layer_fn.clone()),
942             trace: self.trace.map(layer_fn.clone()),
943             fallback: self.fallback.map(layer_fn),
944             allow_header: self.allow_header,
945         }
946     }
947 
948     #[doc = include_str!("../docs/method_routing/route_layer.md")]
949     #[track_caller]
route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E> where L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static, E: 'static, S: 'static,950     pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, B, E>
951     where
952         L: Layer<Route<B, E>> + Clone + Send + 'static,
953         L::Service: Service<Request<B>, Error = E> + Clone + Send + 'static,
954         <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
955         <L::Service as Service<Request<B>>>::Future: Send + 'static,
956         E: 'static,
957         S: 'static,
958     {
959         if self.get.is_none()
960             && self.head.is_none()
961             && self.delete.is_none()
962             && self.options.is_none()
963             && self.patch.is_none()
964             && self.post.is_none()
965             && self.put.is_none()
966             && self.trace.is_none()
967         {
968             panic!(
969                 "Adding a route_layer before any routes is a no-op. \
970                  Add the routes you want the layer to apply to first."
971             );
972         }
973 
974         let layer_fn = move |svc| {
975             let svc = layer.layer(svc);
976             let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
977             Route::new(svc)
978         };
979 
980         self.get = self.get.map(layer_fn.clone());
981         self.head = self.head.map(layer_fn.clone());
982         self.delete = self.delete.map(layer_fn.clone());
983         self.options = self.options.map(layer_fn.clone());
984         self.patch = self.patch.map(layer_fn.clone());
985         self.post = self.post.map(layer_fn.clone());
986         self.put = self.put.map(layer_fn.clone());
987         self.trace = self.trace.map(layer_fn);
988 
989         self
990     }
991 
992     #[track_caller]
merge_for_path( mut self, path: Option<&str>, other: MethodRouter<S, B, E>, ) -> Self993     pub(crate) fn merge_for_path(
994         mut self,
995         path: Option<&str>,
996         other: MethodRouter<S, B, E>,
997     ) -> Self {
998         // written using inner functions to generate less IR
999         #[track_caller]
1000         fn merge_inner<S, B, E>(
1001             path: Option<&str>,
1002             name: &str,
1003             first: MethodEndpoint<S, B, E>,
1004             second: MethodEndpoint<S, B, E>,
1005         ) -> MethodEndpoint<S, B, E> {
1006             match (first, second) {
1007                 (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1008                 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1009                 _ => {
1010                     if let Some(path) = path {
1011                         panic!(
1012                             "Overlapping method route. Handler for `{name} {path}` already exists"
1013                         );
1014                     } else {
1015                         panic!(
1016                             "Overlapping method route. Cannot merge two method routes that both \
1017                              define `{name}`"
1018                         );
1019                     }
1020                 }
1021             }
1022         }
1023 
1024         self.get = merge_inner(path, "GET", self.get, other.get);
1025         self.head = merge_inner(path, "HEAD", self.head, other.head);
1026         self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1027         self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1028         self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1029         self.post = merge_inner(path, "POST", self.post, other.post);
1030         self.put = merge_inner(path, "PUT", self.put, other.put);
1031         self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1032 
1033         self.fallback = self
1034             .fallback
1035             .merge(other.fallback)
1036             .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1037 
1038         self.allow_header = self.allow_header.merge(other.allow_header);
1039 
1040         self
1041     }
1042 
1043     #[doc = include_str!("../docs/method_routing/merge.md")]
1044     #[track_caller]
merge(self, other: MethodRouter<S, B, E>) -> Self1045     pub fn merge(self, other: MethodRouter<S, B, E>) -> Self {
1046         self.merge_for_path(None, other)
1047     }
1048 
1049     /// Apply a [`HandleErrorLayer`].
1050     ///
1051     /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible> where F: Clone + Send + Sync + 'static, HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>, <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send, <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send, T: 'static, E: 'static, B: 'static, S: 'static,1052     pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible>
1053     where
1054         F: Clone + Send + Sync + 'static,
1055         HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>,
1056         <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send,
1057         <HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
1058         T: 'static,
1059         E: 'static,
1060         B: 'static,
1061         S: 'static,
1062     {
1063         self.layer(HandleErrorLayer::new(f))
1064     }
1065 
skip_allow_header(mut self) -> Self1066     fn skip_allow_header(mut self) -> Self {
1067         self.allow_header = AllowHeader::Skip;
1068         self
1069     }
1070 
    /// Dispatches `req` to the endpoint registered for its HTTP method, using
    /// `state` to build a route from any boxed handler.
    ///
    /// When no endpoint matches, the request goes to the fallback and the stored
    /// allow header is passed on to the returned future.
    pub(crate) fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
        // Expands to: "if the request method is `$method_variant` and `$svc` holds
        // an endpoint, return that endpoint's future from the enclosing function".
        // An empty slot falls through so later `call!` lines (and finally the
        // fallback) get a chance.
        macro_rules! call {
            (
                $req:expr,
                $method:expr,
                $method_variant:ident,
                $svc:expr
            ) => {
                if $method == Method::$method_variant {
                    match $svc {
                        MethodEndpoint::None => {}
                        MethodEndpoint::Route(route) => {
                            // `strip_body` is told whether this is a `HEAD` request.
                            return RouteFuture::from_future(route.oneshot_inner($req))
                                .strip_body($method == Method::HEAD);
                        }
                        MethodEndpoint::BoxedHandler(handler) => {
                            // The boxed handler is cloned and turned into a route
                            // with the caller-provided state on every call.
                            let mut route = handler.clone().into_route(state);
                            return RouteFuture::from_future(route.oneshot_inner($req))
                                .strip_body($method == Method::HEAD);
                        }
                    }
                }
            };
        }

        // Cloned up front because `req` is moved into whichever endpoint handles it.
        let method = req.method().clone();

        // Destructured exhaustively (no `..`) so that adding a field to the struct
        // is a compile error here until the new field is handled below.
        let Self {
            get,
            head,
            delete,
            options,
            patch,
            post,
            put,
            trace,
            fallback,
            allow_header,
        } = self;

        // Order matters: a `HEAD` request prefers the dedicated `head` endpoint and
        // only then falls back to the `get` endpoint.
        call!(req, method, HEAD, head);
        call!(req, method, HEAD, get);
        call!(req, method, GET, get);
        call!(req, method, POST, post);
        call!(req, method, OPTIONS, options);
        call!(req, method, PATCH, patch);
        call!(req, method, PUT, put);
        call!(req, method, DELETE, delete);
        call!(req, method, TRACE, trace);

        // Nothing matched: hand the request to the fallback.
        let future = fallback.call_with_state(req, state);

        match allow_header {
            // No methods were recorded: an empty value is passed on.
            AllowHeader::None => future.allow_header(Bytes::new()),
            // The future is returned without any allow header attached.
            AllowHeader::Skip => future,
            // Pass on a frozen copy of the accumulated method list.
            AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
        }
    }
1130 }
1131 
append_allow_header(allow_header: &mut AllowHeader, method: &'static str)1132 fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1133     match allow_header {
1134         AllowHeader::None => {
1135             *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1136         }
1137         AllowHeader::Skip => {}
1138         AllowHeader::Bytes(allow_header) => {
1139             if let Ok(s) = std::str::from_utf8(allow_header) {
1140                 if !s.contains(method) {
1141                     allow_header.extend_from_slice(b",");
1142                     allow_header.extend_from_slice(method.as_bytes());
1143                 }
1144             } else {
1145                 #[cfg(debug_assertions)]
1146                 panic!("`allow_header` contained invalid uft-8. This should never happen")
1147             }
1148         }
1149     }
1150 }
1151 
1152 impl<S, B, E> Clone for MethodRouter<S, B, E> {
clone(&self) -> Self1153     fn clone(&self) -> Self {
1154         Self {
1155             get: self.get.clone(),
1156             head: self.head.clone(),
1157             delete: self.delete.clone(),
1158             options: self.options.clone(),
1159             patch: self.patch.clone(),
1160             post: self.post.clone(),
1161             put: self.put.clone(),
1162             trace: self.trace.clone(),
1163             fallback: self.fallback.clone(),
1164             allow_header: self.allow_header.clone(),
1165         }
1166     }
1167 }
1168 
1169 impl<S, B, E> Default for MethodRouter<S, B, E>
1170 where
1171     B: HttpBody + Send + 'static,
1172     S: Clone,
1173 {
default() -> Self1174     fn default() -> Self {
1175         Self::new()
1176     }
1177 }
1178 
/// What is registered in a single HTTP-method slot of a `MethodRouter`.
enum MethodEndpoint<S, B, E> {
    /// Nothing is registered for this method.
    None,
    /// A route that can be called directly.
    Route(Route<B, E>),
    /// A boxed handler that is turned into a `Route` via `into_route` once a
    /// state value is supplied.
    BoxedHandler(BoxedIntoRoute<S, B, E>),
}
1184 
1185 impl<S, B, E> MethodEndpoint<S, B, E>
1186 where
1187     S: Clone,
1188 {
is_some(&self) -> bool1189     fn is_some(&self) -> bool {
1190         matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1191     }
1192 
is_none(&self) -> bool1193     fn is_none(&self) -> bool {
1194         matches!(self, Self::None)
1195     }
1196 
map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2> where S: 'static, B: 'static, E: 'static, F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static, B2: HttpBody + 'static, E2: 'static,1197     fn map<F, B2, E2>(self, f: F) -> MethodEndpoint<S, B2, E2>
1198     where
1199         S: 'static,
1200         B: 'static,
1201         E: 'static,
1202         F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
1203         B2: HttpBody + 'static,
1204         E2: 'static,
1205     {
1206         match self {
1207             Self::None => MethodEndpoint::None,
1208             Self::Route(route) => MethodEndpoint::Route(f(route)),
1209             Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1210         }
1211     }
1212 
with_state<S2>(self, state: &S) -> MethodEndpoint<S2, B, E>1213     fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, B, E> {
1214         match self {
1215             MethodEndpoint::None => MethodEndpoint::None,
1216             MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1217             MethodEndpoint::BoxedHandler(handler) => {
1218                 MethodEndpoint::Route(handler.into_route(state.clone()))
1219             }
1220         }
1221     }
1222 }
1223 
1224 impl<S, B, E> Clone for MethodEndpoint<S, B, E> {
clone(&self) -> Self1225     fn clone(&self) -> Self {
1226         match self {
1227             Self::None => Self::None,
1228             Self::Route(inner) => Self::Route(inner.clone()),
1229             Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1230         }
1231     }
1232 }
1233 
1234 impl<S, B, E> fmt::Debug for MethodEndpoint<S, B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result1235     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1236         match self {
1237             Self::None => f.debug_tuple("None").finish(),
1238             Self::Route(inner) => inner.fmt(f),
1239             Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1240         }
1241     }
1242 }
1243 
1244 impl<B, E> Service<Request<B>> for MethodRouter<(), B, E>
1245 where
1246     B: HttpBody + Send + 'static,
1247 {
1248     type Response = Response;
1249     type Error = E;
1250     type Future = RouteFuture<B, E>;
1251 
1252     #[inline]
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>1253     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1254         Poll::Ready(Ok(()))
1255     }
1256 
1257     #[inline]
call(&mut self, req: Request<B>) -> Self::Future1258     fn call(&mut self, req: Request<B>) -> Self::Future {
1259         self.call_with_state(req, ())
1260     }
1261 }
1262 
1263 impl<S, B> Handler<(), S, B> for MethodRouter<S, B>
1264 where
1265     S: Clone + 'static,
1266     B: HttpBody + Send + 'static,
1267 {
1268     type Future = InfallibleRouteFuture<B>;
1269 
call(mut self, req: Request<B>, state: S) -> Self::Future1270     fn call(mut self, req: Request<B>, state: S) -> Self::Future {
1271         InfallibleRouteFuture::new(self.call_with_state(req, state))
1272     }
1273 }
1274 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        body::Body, error_handling::HandleErrorLayer, extract::State,
        handler::HandlerWithoutStateExt,
    };
    use axum_core::response::IntoResponse;
    use http::{header::ALLOW, HeaderMap};
    use std::time::Duration;
    use tower::{timeout::TimeoutLayer, Service, ServiceBuilder, ServiceExt};
    use tower_http::{services::fs::ServeDir, validate_request::ValidateRequestHeaderLayer};

    // An empty router answers every method with `405` and an empty body.
    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    // A plain `tower` service can be routed to with `get_service`.
    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request<Body>) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    // A handler registered with `get` serves `GET` requests.
    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    // `HEAD` requests reach the `get` handler, but the body is removed.
    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    // A dedicated `head` handler wins over the `get` handler for `HEAD` requests.
    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    // Merging two routers makes both of their methods available.
    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    // `layer` also wraps the fallback, so unrouted methods hit the middleware too.
    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    // `route_layer` only wraps existing routes; unrouted methods still get `405`.
    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Never called: exists only to check that a router combining every builder
    // method type-checks.
    #[allow(dead_code)]
    fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use all the things
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(
                    ServiceBuilder::new()
                        .layer(HandleErrorLayer::new(|_| async {
                            StatusCode::REQUEST_TIMEOUT
                        }))
                        .layer(TimeoutLayer::new(Duration::from_secs(10))),
                ),
        );

        crate::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(app.into_make_service());
    }

    // The `Allow` header lists the registered methods in registration order.
    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    // Registering `head` after `get` does not duplicate `HEAD` in the `Allow` header.
    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // Without any routes the `Allow` header is present but empty.
    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    // Merging concatenates the `Allow` lists of both routers.
    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    // `any` handles every method and therefore sets no `Allow` header.
    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    // A custom fallback still gets the computed `Allow` header attached.
    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // An `Allow` header set by the fallback itself is not overwritten.
    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    // A no-op middleware leaves the computed `Allow` header intact.
    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // Registering two handlers for the same method panics.
    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    // Registering two services for the same method panics.
    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    // `get` followed by `head` is not treated as an overlap.
    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    // `head` followed by `get` is not treated as an overlap.
    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    // Handlers can extract the state supplied through `with_state`.
    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    // The fallback can extract the state supplied through `with_state`.
    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    // Both sides of a merge can extract the state supplied afterwards.
    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Bind the POST body as well; otherwise the assertion below would only
        // re-check the body of the GET response.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    // Sends an empty-bodied request for `/` with the given method through `svc`
    // and returns the response status, headers and body decoded as UTF-8.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request<Body>, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body = String::from_utf8(hyper::body::to_bytes(body).await.unwrap().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    // Handler that responds `200 OK` with the body `ok`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    // Handler that responds `201 Created` with the body `created`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}
1588