1 use crate::{
2 extract::FromRequestParts,
3 response::{IntoResponse, Response},
4 };
5 use futures_util::{future::BoxFuture, ready};
6 use http::Request;
7 use pin_project_lite::pin_project;
8 use std::{
9 fmt,
10 future::Future,
11 marker::PhantomData,
12 pin::Pin,
13 task::{Context, Poll},
14 };
15 use tower_layer::Layer;
16 use tower_service::Service;
17
18 /// Create a middleware from an extractor.
19 ///
20 /// If the extractor succeeds the value will be discarded and the inner service
21 /// will be called. If the extractor fails the rejection will be returned and
22 /// the inner service will _not_ be called.
23 ///
24 /// This can be used to perform validation of requests if the validation doesn't
25 /// produce any useful output, and run the extractor for several handlers
26 /// without repeating it in the function signature.
27 ///
28 /// Note that if the extractor consumes the request body, as `String` or
29 /// [`Bytes`] does, an empty body will be left in its place. Thus wont be
30 /// accessible to subsequent extractors or handlers.
31 ///
32 /// # Example
33 ///
34 /// ```rust
35 /// use axum::{
36 /// extract::FromRequestParts,
37 /// middleware::from_extractor,
38 /// routing::{get, post},
39 /// Router,
40 /// http::{header, StatusCode, request::Parts},
41 /// };
42 /// use async_trait::async_trait;
43 ///
44 /// // An extractor that performs authorization.
45 /// struct RequireAuth;
46 ///
47 /// #[async_trait]
48 /// impl<S> FromRequestParts<S> for RequireAuth
49 /// where
50 /// S: Send + Sync,
51 /// {
52 /// type Rejection = StatusCode;
53 ///
54 /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
55 /// let auth_header = parts
56 /// .headers
57 /// .get(header::AUTHORIZATION)
58 /// .and_then(|value| value.to_str().ok());
59 ///
60 /// match auth_header {
61 /// Some(auth_header) if token_is_valid(auth_header) => {
62 /// Ok(Self)
63 /// }
64 /// _ => Err(StatusCode::UNAUTHORIZED),
65 /// }
66 /// }
67 /// }
68 ///
69 /// fn token_is_valid(token: &str) -> bool {
70 /// // ...
71 /// # false
72 /// }
73 ///
74 /// async fn handler() {
75 /// // If we get here the request has been authorized
76 /// }
77 ///
78 /// async fn other_handler() {
79 /// // If we get here the request has been authorized
80 /// }
81 ///
82 /// let app = Router::new()
83 /// .route("/", get(handler))
84 /// .route("/foo", post(other_handler))
85 /// // The extractor will run before all routes
86 /// .route_layer(from_extractor::<RequireAuth>());
87 /// # async {
88 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
89 /// # };
90 /// ```
91 ///
92 /// [`Bytes`]: bytes::Bytes
from_extractor<E>() -> FromExtractorLayer<E, ()>93 pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
94 from_extractor_with_state(())
95 }
96
97 /// Create a middleware from an extractor with the given state.
98 ///
99 /// See [`State`](crate::extract::State) for more details about accessing state.
from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S>100 pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
101 FromExtractorLayer {
102 state,
103 _marker: PhantomData,
104 }
105 }
106
/// [`Layer`] that applies [`FromExtractor`] that runs an extractor and
/// discards the value.
///
/// The type parameter `E` is the extractor to run and `S` is the state the
/// extractor is given.
///
/// See [`from_extractor`] for more details.
///
/// [`Layer`]: tower::Layer
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // State that is cloned into every `FromExtractor` produced by `layer`.
    state: S,
    // Records the extractor type without storing a value of it. `fn() -> E`
    // is used instead of `E` so the layer's auto traits (`Send`/`Sync`) do not
    // depend on `E`.
    _marker: PhantomData<fn() -> E>,
}
118
119 impl<E, S> Clone for FromExtractorLayer<E, S>
120 where
121 S: Clone,
122 {
clone(&self) -> Self123 fn clone(&self) -> Self {
124 Self {
125 state: self.state.clone(),
126 _marker: PhantomData,
127 }
128 }
129 }
130
131 impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
132 where
133 S: fmt::Debug,
134 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("FromExtractorLayer")
137 .field("state", &self.state)
138 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
139 .finish()
140 }
141 }
142
143 impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
144 where
145 S: Clone,
146 {
147 type Service = FromExtractor<T, E, S>;
148
layer(&self, inner: T) -> Self::Service149 fn layer(&self, inner: T) -> Self::Service {
150 FromExtractor {
151 inner,
152 state: self.state.clone(),
153 _extractor: PhantomData,
154 }
155 }
156 }
157
/// Middleware that runs an extractor and discards the value.
///
/// `T` is the wrapped service, `E` the extractor that is run for each request
/// and `S` the state passed to the extractor.
///
/// See [`from_extractor`] for more details.
pub struct FromExtractor<T, E, S> {
    // The wrapped service that is called when extraction succeeds.
    inner: T,
    // State handed (by reference, from a per-request clone) to the extractor.
    state: S,
    // Records the extractor type without storing a value of it. `fn() -> E`
    // keeps `Send`/`Sync` independent of `E` (checked by the `traits` test).
    _extractor: PhantomData<fn() -> E>,
}
166
/// Compile-time check: the middleware is `Send` and `Sync` even when the
/// extractor type is neither.
#[test]
fn traits() {
    use crate::test_helpers::*;

    type Middleware = FromExtractor<(), NotSendSync, ()>;

    assert_send::<Middleware>();
    assert_sync::<Middleware>();
}
173
174 impl<T, E, S> Clone for FromExtractor<T, E, S>
175 where
176 T: Clone,
177 S: Clone,
178 {
clone(&self) -> Self179 fn clone(&self) -> Self {
180 Self {
181 inner: self.inner.clone(),
182 state: self.state.clone(),
183 _extractor: PhantomData,
184 }
185 }
186 }
187
188 impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
189 where
190 T: fmt::Debug,
191 S: fmt::Debug,
192 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 f.debug_struct("FromExtractor")
195 .field("inner", &self.inner)
196 .field("state", &self.state)
197 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
198 .finish()
199 }
200 }
201
202 impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
203 where
204 E: FromRequestParts<S> + 'static,
205 B: Send + 'static,
206 T: Service<Request<B>> + Clone,
207 T::Response: IntoResponse,
208 S: Clone + Send + Sync + 'static,
209 {
210 type Response = Response;
211 type Error = T::Error;
212 type Future = ResponseFuture<B, T, E, S>;
213
214 #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>215 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216 self.inner.poll_ready(cx)
217 }
218
call(&mut self, req: Request<B>) -> Self::Future219 fn call(&mut self, req: Request<B>) -> Self::Future {
220 let state = self.state.clone();
221 let extract_future = Box::pin(async move {
222 let (mut parts, body) = req.into_parts();
223 let extracted = E::from_request_parts(&mut parts, &state).await;
224 let req = Request::from_parts(parts, body);
225 (req, extracted)
226 });
227
228 ResponseFuture {
229 state: State::Extracting {
230 future: extract_future,
231 },
232 svc: Some(self.inner.clone()),
233 }
234 }
235 }
236
pin_project! {
    /// Response future for [`FromExtractor`].
    ///
    /// Runs the extractor first. If extraction succeeds the request is passed
    /// to the inner service and this future resolves to that service's
    /// response; otherwise it resolves to the rejection's response.
    #[allow(missing_debug_implementations)]
    pub struct ResponseFuture<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Current phase: still extracting, or awaiting the inner service.
        #[pin]
        state: State<B, T, E, S>,
        // The inner service; taken (leaving `None`) when the request is
        // forwarded after a successful extraction.
        svc: Option<T>,
    }
}
250
pin_project! {
    // The two phases `ResponseFuture` moves through, in order.
    #[project = StateProj]
    enum State<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // The extractor is running. The boxed future yields the reassembled
        // request together with the extraction result.
        Extracting {
            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
        },
        // Extraction succeeded and the inner service's future is being polled.
        Call { #[pin] future: T::Future },
    }
}
264
265 impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
266 where
267 E: FromRequestParts<S>,
268 T: Service<Request<B>>,
269 T::Response: IntoResponse,
270 {
271 type Output = Result<Response, T::Error>;
272
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>273 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274 loop {
275 let mut this = self.as_mut().project();
276
277 let new_state = match this.state.as_mut().project() {
278 StateProj::Extracting { future } => {
279 let (req, extracted) = ready!(future.as_mut().poll(cx));
280
281 match extracted {
282 Ok(_) => {
283 let mut svc = this.svc.take().expect("future polled after completion");
284 let future = svc.call(req);
285 State::Call { future }
286 }
287 Err(err) => {
288 let res = err.into_response();
289 return Poll::Ready(Ok(res));
290 }
291 }
292 }
293 StateProj::Call { future } => {
294 return future
295 .poll(cx)
296 .map(|result| result.map(IntoResponse::into_response));
297 }
298 };
299
300 this.state.set(new_state);
301 }
302 }
303 }
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{async_trait, handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    // Requests without the expected `Authorization` value are rejected by the
    // extractor; requests carrying it reach the handler.
    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        struct RequireAuth;

        #[async_trait]
        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(expected) = Secret::from_ref(state);
                let provided = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok());

                match provided {
                    Some(token) if token == expected => Ok(Self),
                    _ => Err(StatusCode::UNAUTHORIZED),
                }
            }
        }

        async fn handler() {}

        let auth_layer = from_extractor_with_state::<RequireAuth, _>(Secret("secret"));
        let app = Router::new().route("/", get(handler.layer(auth_layer)));

        let client = TestClient::new(app);

        // No credentials: the extractor rejects the request.
        let unauthenticated = client.get("/").send().await;
        assert_eq!(unauthenticated.status(), StatusCode::UNAUTHORIZED);

        // Matching credentials: the handler runs.
        let authenticated = client
            .get("/")
            .header(header::AUTHORIZATION, "secret")
            .send()
            .await;
        assert_eq!(authenticated.status(), StatusCode::OK);
    }

    // just needs to compile
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        #[async_trait]
        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let extractor_layer = from_extractor::<MyExtractor>();
        let body_limit_layer = RequestBodyLimitLayer::new(1);

        let _: Router = Router::new()
            .layer(extractor_layer)
            .layer(body_limit_layer);
    }
}
393