1 use crate::{
2     body::{boxed, Body, Empty, HttpBody},
3     response::Response,
4 };
5 use axum_core::response::IntoResponse;
6 use bytes::Bytes;
7 use http::{
8     header::{self, CONTENT_LENGTH},
9     HeaderMap, HeaderValue, Request,
10 };
11 use pin_project_lite::pin_project;
12 use std::{
13     convert::Infallible,
14     fmt,
15     future::Future,
16     pin::Pin,
17     task::{Context, Poll},
18 };
19 use tower::{
20     util::{BoxCloneService, MapResponseLayer, Oneshot},
21     ServiceBuilder, ServiceExt,
22 };
23 use tower_layer::Layer;
24 use tower_service::Service;
25 
/// How routes are stored inside a [`Router`](super::Router).
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
///
/// Internally this is a type-erased, cloneable service (`BoxCloneService`) that
/// accepts `Request<B>` and produces either a [`Response`] or an error of type `E`.
/// `B` defaults to [`Body`] and `E` defaults to [`Infallible`].
pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>);
31 
32 impl<B, E> Route<B, E> {
new<T>(svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static,33     pub(crate) fn new<T>(svc: T) -> Self
34     where
35         T: Service<Request<B>, Error = E> + Clone + Send + 'static,
36         T::Response: IntoResponse + 'static,
37         T::Future: Send + 'static,
38     {
39         Self(BoxCloneService::new(
40             svc.map_response(IntoResponse::into_response),
41         ))
42     }
43 
oneshot_inner( &mut self, req: Request<B>, ) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>>44     pub(crate) fn oneshot_inner(
45         &mut self,
46         req: Request<B>,
47     ) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
48         self.0.clone().oneshot(req)
49     }
50 
layer<L, NewReqBody, NewError>(self, layer: L) -> Route<NewReqBody, NewError> where L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: 'static, NewError: 'static,51     pub(crate) fn layer<L, NewReqBody, NewError>(self, layer: L) -> Route<NewReqBody, NewError>
52     where
53         L: Layer<Route<B, E>> + Clone + Send + 'static,
54         L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
55         <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
56         <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
57         <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
58         NewReqBody: 'static,
59         NewError: 'static,
60     {
61         let layer = ServiceBuilder::new()
62             .map_err(Into::into)
63             .layer(MapResponseLayer::new(IntoResponse::into_response))
64             .layer(layer)
65             .into_inner();
66 
67         Route::new(layer.layer(self))
68     }
69 }
70 
71 impl<B, E> Clone for Route<B, E> {
clone(&self) -> Self72     fn clone(&self) -> Self {
73         Self(self.0.clone())
74     }
75 }
76 
77 impl<B, E> fmt::Debug for Route<B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result78     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79         f.debug_struct("Route").finish()
80     }
81 }
82 
83 impl<B, E> Service<Request<B>> for Route<B, E>
84 where
85     B: HttpBody,
86 {
87     type Response = Response;
88     type Error = E;
89     type Future = RouteFuture<B, E>;
90 
91     #[inline]
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>92     fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93         Poll::Ready(Ok(()))
94     }
95 
96     #[inline]
call(&mut self, req: Request<B>) -> Self::Future97     fn call(&mut self, req: Request<B>) -> Self::Future {
98         RouteFuture::from_future(self.oneshot_inner(req))
99     }
100 }
101 
pin_project! {
    /// Response future for [`Route`].
    ///
    /// Resolves to the inner service's response, or passes its error through
    /// unchanged. Before a response is returned, several adjustments are made:
    /// - an `Allow` header may be added;
    /// - `Content-Length` is set from the body's exact size hint, if one is known
    ///   and the header is missing;
    /// - if requested, the body is replaced with an empty one.
    pub struct RouteFuture<B, E> {
        // Either the in-flight service call or an already-available response.
        #[pin]
        kind: RouteFutureKind<B, E>,
        // When `true`, the body is swapped for an empty one after
        // `Content-Length` has been computed from the original body.
        strip_body: bool,
        // Value to insert as the `Allow` header if the response has none; it is
        // taken (left as `None`) when the response is produced.
        allow_header: Option<Bytes>,
    }
}
111 
pin_project! {
    // The two states a `RouteFuture` can be in.
    #[project = RouteFutureKindProj]
    enum RouteFutureKind<B, E> {
        // A call into the route's boxed service that is still being driven.
        Future {
            #[pin]
            future: Oneshot<
                BoxCloneService<Request<B>, Response, E>,
                Request<B>,
            >,
        },
        // A response that is already available. It is moved out (leaving `None`)
        // the first time the future is polled; polling again panics.
        // NOTE(review): no constructor for this variant is visible in this code;
        // presumably it is built elsewhere — confirm.
        Response {
            response: Option<Response>,
        }
    }
}
127 
128 impl<B, E> RouteFuture<B, E> {
from_future( future: Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>>, ) -> Self129     pub(crate) fn from_future(
130         future: Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>>,
131     ) -> Self {
132         Self {
133             kind: RouteFutureKind::Future { future },
134             strip_body: false,
135             allow_header: None,
136         }
137     }
138 
strip_body(mut self, strip_body: bool) -> Self139     pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
140         self.strip_body = strip_body;
141         self
142     }
143 
allow_header(mut self, allow_header: Bytes) -> Self144     pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
145         self.allow_header = Some(allow_header);
146         self
147     }
148 }
149 
150 impl<B, E> Future for RouteFuture<B, E>
151 where
152     B: HttpBody,
153 {
154     type Output = Result<Response, E>;
155 
156     #[inline]
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>157     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158         let this = self.project();
159 
160         let mut res = match this.kind.project() {
161             RouteFutureKindProj::Future { future } => match future.poll(cx) {
162                 Poll::Ready(Ok(res)) => res,
163                 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
164                 Poll::Pending => return Poll::Pending,
165             },
166             RouteFutureKindProj::Response { response } => {
167                 response.take().expect("future polled after completion")
168             }
169         };
170 
171         set_allow_header(res.headers_mut(), this.allow_header);
172 
173         // make sure to set content-length before removing the body
174         set_content_length(res.size_hint(), res.headers_mut());
175 
176         let res = if *this.strip_body {
177             res.map(|_| boxed(Empty::new()))
178         } else {
179             res
180         };
181 
182         Poll::Ready(Ok(res))
183     }
184 }
185 
set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>)186 fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
187     match allow_header.take() {
188         Some(allow_header) if !headers.contains_key(header::ALLOW) => {
189             headers.insert(
190                 header::ALLOW,
191                 HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
192             );
193         }
194         _ => {}
195     }
196 }
197 
set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap)198 fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
199     if headers.contains_key(CONTENT_LENGTH) {
200         return;
201     }
202 
203     if let Some(size) = size_hint.exact() {
204         let header_value = if size == 0 {
205             #[allow(clippy::declare_interior_mutable_const)]
206             const ZERO: HeaderValue = HeaderValue::from_static("0");
207 
208             ZERO
209         } else {
210             let mut buffer = itoa::Buffer::new();
211             HeaderValue::from_str(buffer.format(size)).unwrap()
212         };
213 
214         headers.insert(CONTENT_LENGTH, header_value);
215     }
216 }
217 
pin_project! {
    /// A [`RouteFuture`] that always yields a [`Response`].
    ///
    /// The wrapped future's error type is [`Infallible`], so this future's
    /// output is the bare [`Response`] rather than a `Result`.
    pub struct InfallibleRouteFuture<B> {
        // The wrapped future; its error type has no values.
        #[pin]
        future: RouteFuture<B, Infallible>,
    }
}
225 
226 impl<B> InfallibleRouteFuture<B> {
new(future: RouteFuture<B, Infallible>) -> Self227     pub(crate) fn new(future: RouteFuture<B, Infallible>) -> Self {
228         Self { future }
229     }
230 }
231 
232 impl<B> Future for InfallibleRouteFuture<B>
233 where
234     B: HttpBody,
235 {
236     type Output = Response;
237 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>238     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
239         match futures_util::ready!(self.project().future.poll(cx)) {
240             Ok(response) => Poll::Ready(response),
241             Err(err) => match err {},
242         }
243     }
244 }
245 
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check that `Route` implements `Send`.
    #[test]
    fn traits() {
        use crate::test_helpers::assert_send;

        assert_send::<Route<()>>();
    }
}
256