1 use crate::response::{IntoResponse, Response};
2 use axum_core::extract::{FromRequest, FromRequestParts};
3 use futures_util::future::BoxFuture;
4 use http::Request;
5 use std::{
6 any::type_name,
7 convert::Infallible,
8 fmt,
9 future::Future,
10 marker::PhantomData,
11 pin::Pin,
12 task::{Context, Poll},
13 };
14 use tower_layer::Layer;
15 use tower_service::Service;
16
17 /// Create a middleware from an async function that transforms a request.
18 ///
19 /// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific
20 /// extractors.
21 ///
22 /// # Example
23 ///
24 /// ```
25 /// use axum::{
26 /// Router,
27 /// routing::get,
28 /// middleware::map_request,
29 /// http::Request,
30 /// };
31 ///
32 /// async fn set_header<B>(mut request: Request<B>) -> Request<B> {
33 /// request.headers_mut().insert("x-foo", "foo".parse().unwrap());
34 /// request
35 /// }
36 ///
37 /// async fn handler<B>(request: Request<B>) {
38 /// // `request` will have an `x-foo` header
39 /// }
40 ///
41 /// let app = Router::new()
42 /// .route("/", get(handler))
43 /// .layer(map_request(set_header));
44 /// # let _: Router = app;
45 /// ```
46 ///
47 /// # Rejecting the request
48 ///
49 /// The function given to `map_request` is allowed to also return a `Result` which can be used to
50 /// reject the request and return a response immediately, without calling the remaining
51 /// middleware.
52 ///
53 /// Specifically the valid return types are:
54 ///
55 /// - `Request<B>`
56 /// - `Result<Request<B>, E> where E: IntoResponse`
57 ///
58 /// ```
59 /// use axum::{
60 /// Router,
61 /// http::{Request, StatusCode},
62 /// routing::get,
63 /// middleware::map_request,
64 /// };
65 ///
66 /// async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
67 /// let auth_header = request.headers()
68 /// .get(http::header::AUTHORIZATION)
69 /// .and_then(|header| header.to_str().ok());
70 ///
71 /// match auth_header {
72 /// Some(auth_header) if token_is_valid(auth_header) => Ok(request),
73 /// _ => Err(StatusCode::UNAUTHORIZED),
74 /// }
75 /// }
76 ///
77 /// fn token_is_valid(token: &str) -> bool {
78 /// // ...
79 /// # false
80 /// }
81 ///
82 /// let app = Router::new()
83 /// .route("/", get(|| async { /* ... */ }))
84 /// .route_layer(map_request(auth));
85 /// # let app: Router = app;
86 /// ```
87 ///
88 /// # Running extractors
89 ///
90 /// ```
91 /// use axum::{
92 /// Router,
93 /// routing::get,
94 /// middleware::map_request,
95 /// extract::Path,
96 /// http::Request,
97 /// };
98 /// use std::collections::HashMap;
99 ///
100 /// async fn log_path_params<B>(
101 /// Path(path_params): Path<HashMap<String, String>>,
102 /// request: Request<B>,
103 /// ) -> Request<B> {
104 /// tracing::debug!(?path_params);
105 /// request
106 /// }
107 ///
108 /// let app = Router::new()
109 /// .route("/", get(|| async { /* ... */ }))
110 /// .layer(map_request(log_path_params));
111 /// # let _: Router = app;
112 /// ```
113 ///
114 /// Note that to access state you must use either [`map_request_with_state`].
map_request<F, T>(f: F) -> MapRequestLayer<F, (), T>115 pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
116 map_request_with_state((), f)
117 }
118
119 /// Create a middleware from an async function that transforms a request, with the given state.
120 ///
121 /// See [`State`](crate::extract::State) for more details about accessing state.
122 ///
123 /// # Example
124 ///
125 /// ```rust
126 /// use axum::{
127 /// Router,
128 /// http::{Request, StatusCode},
129 /// routing::get,
130 /// response::IntoResponse,
131 /// middleware::map_request_with_state,
132 /// extract::State,
133 /// };
134 ///
135 /// #[derive(Clone)]
136 /// struct AppState { /* ... */ }
137 ///
138 /// async fn my_middleware<B>(
139 /// State(state): State<AppState>,
140 /// // you can add more extractors here but the last
141 /// // extractor must implement `FromRequest` which
142 /// // `Request` does
143 /// request: Request<B>,
144 /// ) -> Request<B> {
145 /// // do something with `state` and `request`...
146 /// request
147 /// }
148 ///
149 /// let state = AppState { /* ... */ };
150 ///
151 /// let app = Router::new()
152 /// .route("/", get(|| async { /* ... */ }))
153 /// .route_layer(map_request_with_state(state.clone(), my_middleware))
154 /// .with_state(state);
155 /// # let _: axum::Router = app;
156 /// ```
map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T>157 pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
158 MapRequestLayer {
159 f,
160 state,
161 _extractor: PhantomData,
162 }
163 }
164
/// A [`tower::Layer`] from an async function that transforms a request.
///
/// Applying it to a service produces a [`MapRequest`] holding clones of the function and state.
///
/// Created with [`map_request`] or [`map_request_with_state`]. See those functions for more
/// details.
#[must_use]
pub struct MapRequestLayer<F, S, T> {
    /// The async function that transforms each request.
    f: F,
    /// State passed by reference to the function's extractors.
    state: S,
    /// Marker for the extractor tuple type. `fn() -> T` does not imply ownership of a `T` and
    /// stays `Send`/`Sync` regardless of `T`.
    _extractor: PhantomData<fn() -> T>,
}
174
175 impl<F, S, T> Clone for MapRequestLayer<F, S, T>
176 where
177 F: Clone,
178 S: Clone,
179 {
clone(&self) -> Self180 fn clone(&self) -> Self {
181 Self {
182 f: self.f.clone(),
183 state: self.state.clone(),
184 _extractor: self._extractor,
185 }
186 }
187 }
188
189 impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
190 where
191 F: Clone,
192 S: Clone,
193 {
194 type Service = MapRequest<F, S, I, T>;
195
layer(&self, inner: I) -> Self::Service196 fn layer(&self, inner: I) -> Self::Service {
197 MapRequest {
198 f: self.f.clone(),
199 state: self.state.clone(),
200 inner,
201 _extractor: PhantomData,
202 }
203 }
204 }
205
206 impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
207 where
208 S: fmt::Debug,
209 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 f.debug_struct("MapRequestLayer")
212 // Write out the type name, without quoting it as `&type_name::<F>()` would
213 .field("f", &format_args!("{}", type_name::<F>()))
214 .field("state", &self.state)
215 .finish()
216 }
217 }
218
/// A middleware created from an async function that transforms a request.
///
/// It runs the function's extractors, calls the function, then either forwards the resulting
/// request to the inner service or returns the function's rejection response directly.
///
/// Created with [`map_request`] or [`map_request_with_state`]. See those functions for more
/// details.
pub struct MapRequest<F, S, I, T> {
    /// The async function that transforms each request.
    f: F,
    /// The wrapped service that receives the transformed request.
    inner: I,
    /// State passed by reference to the function's extractors.
    state: S,
    /// Marker for the extractor tuple type; `fn() -> T` does not imply ownership of a `T`.
    _extractor: PhantomData<fn() -> T>,
}
228
229 impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
230 where
231 F: Clone,
232 I: Clone,
233 S: Clone,
234 {
clone(&self) -> Self235 fn clone(&self) -> Self {
236 Self {
237 f: self.f.clone(),
238 inner: self.inner.clone(),
239 state: self.state.clone(),
240 _extractor: self._extractor,
241 }
242 }
243 }
244
// Implements `Service<Request<B>>` for `MapRequest` for one extractor-tuple arity.
//
// `$ty` are the leading extractors, which only see the request parts (`FromRequestParts`).
// `$last` is the final extractor, which receives the whole request by value (`FromRequest`)
// and may therefore consume the body.
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the extractor type names are reused as local variable names.
        // `unused_mut`: with no `$ty` extractors, `parts` is never mutated.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S, B> + Send,
            Fut: Future + Send + 'static,
            Fut::Output: IntoMapRequestResult<B> + Send + 'static,
            I: Service<Request<B>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            B: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the inner service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Move the inner service that `poll_ready` drove to readiness into the future,
                // leaving a fresh clone in `self` for subsequent calls.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                let state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run each parts-only extractor in order. A rejection becomes the
                    // response, and neither `f` nor the inner service is called.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request for the final extractor, which takes it by value.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // `f` yields either a request to forward to the inner service, or a
                    // response that short-circuits the rest of the stack.
                    match f($($ty,)* $last).await.into_map_request_result() {
                        Ok(req) => {
                            ready_inner.call(req).await.into_response()
                        }
                        Err(res) => {
                            res
                        }
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// `all_the_tuples!` is defined elsewhere in the crate; presumably it invokes `impl_service!`
// once per supported number of extractors. NOTE(review): confirm against its definition.
all_the_tuples!(impl_service);
317
318 impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
319 where
320 S: fmt::Debug,
321 I: fmt::Debug,
322 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result323 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324 f.debug_struct("MapRequest")
325 .field("f", &format_args!("{}", type_name::<F>()))
326 .field("inner", &self.inner)
327 .field("state", &self.state)
328 .finish()
329 }
330 }
331
/// Response future for [`MapRequest`].
///
/// Resolves to an extractor rejection, the short-circuit response returned by the function, or
/// the inner service's response. It never resolves to an error (the error type is
/// [`Infallible`]).
pub struct ResponseFuture {
    /// The boxed `async` block built by `MapRequest::call`.
    inner: BoxFuture<'static, Response>,
}
336
337 impl Future for ResponseFuture {
338 type Output = Result<Response, Infallible>;
339
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>340 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
341 self.inner.as_mut().poll(cx).map(Ok)
342 }
343 }
344
345 impl fmt::Debug for ResponseFuture {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result346 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 f.debug_struct("ResponseFuture").finish()
348 }
349 }
350
351 mod private {
352 use crate::{http::Request, response::IntoResponse};
353
354 pub trait Sealed<B> {}
355 impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
356 impl<B> Sealed<B> for Request<B> {}
357 }
358
/// Trait implemented by types that can be returned from [`map_request`] or
/// [`map_request_with_state`].
///
/// `Ok` means "continue": the request is forwarded to the inner service. `Err` is a response
/// that is returned immediately, without calling the inner service.
///
/// This trait is sealed such that it cannot be implemented outside this crate.
pub trait IntoMapRequestResult<B>: private::Sealed<B> {
    /// Perform the conversion.
    fn into_map_request_result(self) -> Result<Request<B>, Response>;
}
367
368 impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
369 where
370 E: IntoResponse,
371 {
into_map_request_result(self) -> Result<Request<B>, Response>372 fn into_map_request_result(self) -> Result<Request<B>, Response> {
373 self.map_err(IntoResponse::into_response)
374 }
375 }
376
377 impl<B> IntoMapRequestResult<B> for Request<B> {
into_map_request_result(self) -> Result<Request<B>, Response>378 fn into_map_request_result(self) -> Result<Request<B>, Response> {
379 Ok(self)
380 }
381 }
382
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, routing::get, test_helpers::TestClient, Router};
    use http::{HeaderMap, StatusCode};

    /// The middleware's modified request reaches the handler.
    #[crate::test]
    async fn works() {
        async fn add_header<B>(mut req: Request<B>) -> Request<B> {
            req.headers_mut().insert("x-foo", "foo".parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(add_header));
        let client = TestClient::new(app);

        let res = client.get("/").send().await;

        assert_eq!(res.text().await, "foo");
    }

    /// Returning `Err` from the middleware short-circuits: the handler is never called.
    #[crate::test]
    async fn works_for_short_circuiting() {
        async fn reject<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
            Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
        }

        async fn handler(_headers: HeaderMap) -> Response {
            unreachable!()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(reject));
        let client = TestClient::new(app);

        let res = client.get("/").send().await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(res.text().await, "something went wrong");
    }

    /// State given to `map_request_with_state` is visible to the middleware's extractors.
    #[crate::test]
    async fn works_with_state() {
        async fn add_state_header<B>(
            State(value): State<&'static str>,
            mut req: Request<B>,
        ) -> Request<B> {
            req.headers_mut().insert("x-foo", value.parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request_with_state("from-state", add_state_header));
        let client = TestClient::new(app);

        let res = client.get("/").send().await;

        assert_eq!(res.text().await, "from-state");
    }
}
435