1 use crate::{
2 body::{boxed, Body, Empty, HttpBody},
3 response::Response,
4 };
5 use axum_core::response::IntoResponse;
6 use bytes::Bytes;
7 use http::{
8 header::{self, CONTENT_LENGTH},
9 HeaderMap, HeaderValue, Request,
10 };
11 use pin_project_lite::pin_project;
12 use std::{
13 convert::Infallible,
14 fmt,
15 future::Future,
16 pin::Pin,
17 task::{Context, Poll},
18 };
19 use tower::{
20 util::{BoxCloneService, MapResponseLayer, Oneshot},
21 ServiceBuilder, ServiceExt,
22 };
23 use tower_layer::Layer;
24 use tower_service::Service;
25
/// How routes are stored inside a [`Router`](super::Router).
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
// Internally this is a boxed, cloneable service whose response type is fixed
// to `Response`; `Route::new` adapts any service whose response implements
// `IntoResponse` to that shape.
pub struct Route<B = Body, E = Infallible>(BoxCloneService<Request<B>, Response, E>);
31
32 impl<B, E> Route<B, E> {
new<T>(svc: T) -> Self where T: Service<Request<B>, Error = E> + Clone + Send + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static,33 pub(crate) fn new<T>(svc: T) -> Self
34 where
35 T: Service<Request<B>, Error = E> + Clone + Send + 'static,
36 T::Response: IntoResponse + 'static,
37 T::Future: Send + 'static,
38 {
39 Self(BoxCloneService::new(
40 svc.map_response(IntoResponse::into_response),
41 ))
42 }
43
oneshot_inner( &mut self, req: Request<B>, ) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>>44 pub(crate) fn oneshot_inner(
45 &mut self,
46 req: Request<B>,
47 ) -> Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>> {
48 self.0.clone().oneshot(req)
49 }
50
layer<L, NewReqBody, NewError>(self, layer: L) -> Route<NewReqBody, NewError> where L: Layer<Route<B, E>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: 'static, NewError: 'static,51 pub(crate) fn layer<L, NewReqBody, NewError>(self, layer: L) -> Route<NewReqBody, NewError>
52 where
53 L: Layer<Route<B, E>> + Clone + Send + 'static,
54 L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
55 <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
56 <L::Service as Service<Request<NewReqBody>>>::Error: Into<NewError> + 'static,
57 <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
58 NewReqBody: 'static,
59 NewError: 'static,
60 {
61 let layer = ServiceBuilder::new()
62 .map_err(Into::into)
63 .layer(MapResponseLayer::new(IntoResponse::into_response))
64 .layer(layer)
65 .into_inner();
66
67 Route::new(layer.layer(self))
68 }
69 }
70
71 impl<B, E> Clone for Route<B, E> {
clone(&self) -> Self72 fn clone(&self) -> Self {
73 Self(self.0.clone())
74 }
75 }
76
77 impl<B, E> fmt::Debug for Route<B, E> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_struct("Route").finish()
80 }
81 }
82
83 impl<B, E> Service<Request<B>> for Route<B, E>
84 where
85 B: HttpBody,
86 {
87 type Response = Response;
88 type Error = E;
89 type Future = RouteFuture<B, E>;
90
91 #[inline]
poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>92 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
93 Poll::Ready(Ok(()))
94 }
95
96 #[inline]
call(&mut self, req: Request<B>) -> Self::Future97 fn call(&mut self, req: Request<B>) -> Self::Future {
98 RouteFuture::from_future(self.oneshot_inner(req))
99 }
100 }
101
pin_project! {
    /// Response future for [`Route`].
    pub struct RouteFuture<B, E> {
        // Either the in-flight service call or an already-available response.
        #[pin]
        kind: RouteFutureKind<B, E>,
        // When true, `poll` replaces the response body with an empty one
        // (after `Content-Length` has been computed from the original body).
        strip_body: bool,
        // Value for the `Allow` header; `poll` inserts it only if the response
        // doesn't already carry one.
        allow_header: Option<Bytes>,
    }
}
111
pin_project! {
    // The two states backing a `RouteFuture`.
    #[project = RouteFutureKindProj]
    enum RouteFutureKind<B, E> {
        // A call to the route's boxed service that has not completed yet.
        Future {
            #[pin]
            future: Oneshot<
                BoxCloneService<Request<B>, Response, E>,
                Request<B>,
            >,
        },
        // A response that is already available. `poll` moves it out of the
        // `Option`, so polling again afterwards panics with
        // "future polled after completion".
        //
        // NOTE(review): nothing in this file constructs this variant (the only
        // constructor, `RouteFuture::from_future`, builds `Future`), and the
        // enum is private — confirm whether it is still needed.
        Response {
            response: Option<Response>,
        }
    }
}
127
128 impl<B, E> RouteFuture<B, E> {
from_future( future: Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>>, ) -> Self129 pub(crate) fn from_future(
130 future: Oneshot<BoxCloneService<Request<B>, Response, E>, Request<B>>,
131 ) -> Self {
132 Self {
133 kind: RouteFutureKind::Future { future },
134 strip_body: false,
135 allow_header: None,
136 }
137 }
138
strip_body(mut self, strip_body: bool) -> Self139 pub(crate) fn strip_body(mut self, strip_body: bool) -> Self {
140 self.strip_body = strip_body;
141 self
142 }
143
allow_header(mut self, allow_header: Bytes) -> Self144 pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self {
145 self.allow_header = Some(allow_header);
146 self
147 }
148 }
149
150 impl<B, E> Future for RouteFuture<B, E>
151 where
152 B: HttpBody,
153 {
154 type Output = Result<Response, E>;
155
156 #[inline]
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>157 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158 let this = self.project();
159
160 let mut res = match this.kind.project() {
161 RouteFutureKindProj::Future { future } => match future.poll(cx) {
162 Poll::Ready(Ok(res)) => res,
163 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
164 Poll::Pending => return Poll::Pending,
165 },
166 RouteFutureKindProj::Response { response } => {
167 response.take().expect("future polled after completion")
168 }
169 };
170
171 set_allow_header(res.headers_mut(), this.allow_header);
172
173 // make sure to set content-length before removing the body
174 set_content_length(res.size_hint(), res.headers_mut());
175
176 let res = if *this.strip_body {
177 res.map(|_| boxed(Empty::new()))
178 } else {
179 res
180 };
181
182 Poll::Ready(Ok(res))
183 }
184 }
185
set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>)186 fn set_allow_header(headers: &mut HeaderMap, allow_header: &mut Option<Bytes>) {
187 match allow_header.take() {
188 Some(allow_header) if !headers.contains_key(header::ALLOW) => {
189 headers.insert(
190 header::ALLOW,
191 HeaderValue::from_maybe_shared(allow_header).expect("invalid `Allow` header"),
192 );
193 }
194 _ => {}
195 }
196 }
197
set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap)198 fn set_content_length(size_hint: http_body::SizeHint, headers: &mut HeaderMap) {
199 if headers.contains_key(CONTENT_LENGTH) {
200 return;
201 }
202
203 if let Some(size) = size_hint.exact() {
204 let header_value = if size == 0 {
205 #[allow(clippy::declare_interior_mutable_const)]
206 const ZERO: HeaderValue = HeaderValue::from_static("0");
207
208 ZERO
209 } else {
210 let mut buffer = itoa::Buffer::new();
211 HeaderValue::from_str(buffer.format(size)).unwrap()
212 };
213
214 headers.insert(CONTENT_LENGTH, header_value);
215 }
216 }
217
pin_project! {
    /// A [`RouteFuture`] that always yields a [`Response`].
    pub struct InfallibleRouteFuture<B> {
        // The inner error type is `Infallible`, so its `Err` case can never
        // occur and `poll` can unwrap the output to a bare `Response`.
        #[pin]
        future: RouteFuture<B, Infallible>,
    }
}
225
226 impl<B> InfallibleRouteFuture<B> {
new(future: RouteFuture<B, Infallible>) -> Self227 pub(crate) fn new(future: RouteFuture<B, Infallible>) -> Self {
228 Self { future }
229 }
230 }
231
232 impl<B> Future for InfallibleRouteFuture<B>
233 where
234 B: HttpBody,
235 {
236 type Output = Response;
237
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>238 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
239 match futures_util::ready!(self.project().future.poll(cx)) {
240 Ok(response) => Poll::Ready(response),
241 Err(err) => match err {},
242 }
243 }
244 }
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::assert_send;

    /// Trait-bound sanity checks for `Route` (see `crate::test_helpers`).
    #[test]
    fn traits() {
        assert_send::<Route<()>>();
    }
}
256