1 use super::{rejection::*, FromRequestParts};
2 use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE};
3 use async_trait::async_trait;
4 use http::request::Parts;
5 use std::{collections::HashMap, sync::Arc};
6 
/// Access the path in the router that matches the request.
///
/// ```
/// use axum::{
///     Router,
///     extract::MatchedPath,
///     routing::get,
/// };
///
/// let app = Router::new().route(
///     "/users/:id",
///     get(|path: MatchedPath| async move {
///         let path = path.as_str();
///         // `path` will be "/users/:id"
///     })
/// );
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Accessing `MatchedPath` via extensions
///
/// `MatchedPath` can also be accessed from middleware via request extensions.
///
/// This is useful for example with [`Trace`](tower_http::trace::Trace) to
/// create a span that contains the matched path:
///
/// ```
/// use axum::{
///     Router,
///     extract::MatchedPath,
///     http::Request,
///     routing::get,
/// };
/// use tower_http::trace::TraceLayer;
///
/// let app = Router::new()
///     .route("/users/:id", get(|| async { /* ... */ }))
///     .layer(
///         TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
///             let path = if let Some(path) = req.extensions().get::<MatchedPath>() {
///                 path.as_str()
///             } else {
///                 req.uri().path()
///             };
///             tracing::info_span!("http-request", %path)
///         }),
///     );
/// # let _: Router = app;
/// ```
///
/// # Matched path in nested routers
///
/// `MatchedPath` isn't accessible in middleware that wraps routes nested with
/// [`Router::nest_service`]:
///
/// ```
/// use axum::{
///     Router,
///     RequestExt,
///     routing::get,
///     extract::{MatchedPath, rejection::MatchedPathRejection},
///     middleware::map_request,
///     http::Request,
///     body::Body,
/// };
///
/// async fn access_matched_path(mut request: Request<Body>) -> Request<Body> {
///     // if `/foo/bar` is called this will be `Err(_)` since that matches
///     // a route inside a service added with `Router::nest_service`
///     let matched_path: Result<MatchedPath, MatchedPathRejection> =
///         request.extract_parts::<MatchedPath>().await;
///
///     request
/// }
///
/// // `MatchedPath` is always accessible on handlers added via `Router::route`
/// async fn handler(matched_path: MatchedPath) {}
///
/// let app = Router::new()
///     .nest_service(
///         "/foo",
///         Router::new().route("/bar", get(handler)),
///     )
///     .layer(map_request(access_matched_path));
/// # let _: Router = app;
/// ```
///
/// Middleware that wraps a router nested with [`Router::nest`] does have
/// access to `MatchedPath`, and sees the full pattern including the nested
/// prefix (`/foo/bar` in a setup like the one above).
///
/// [`Router::nest_service`]: crate::Router::nest_service
/// [`Router::nest`]: crate::Router::nest
#[cfg_attr(docsrs, doc(cfg(feature = "matched-path")))]
#[derive(Clone, Debug)]
pub struct MatchedPath(pub(crate) Arc<str>);
99 
100 impl MatchedPath {
101     /// Returns a `str` representation of the path.
as_str(&self) -> &str102     pub fn as_str(&self) -> &str {
103         &self.0
104     }
105 }
106 
107 #[async_trait]
108 impl<S> FromRequestParts<S> for MatchedPath
109 where
110     S: Send + Sync,
111 {
112     type Rejection = MatchedPathRejection;
113 
from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>114     async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
115         let matched_path = parts
116             .extensions
117             .get::<Self>()
118             .ok_or(MatchedPathRejection::MatchedPathMissing(MatchedPathMissing))?
119             .clone();
120 
121         Ok(matched_path)
122     }
123 }
124 
// Private counterpart of `MatchedPath`. `set_matched_path_for_request` stores
// this in the request extensions (instead of a `MatchedPath`) when the matched
// pattern ends with `NEST_TAIL_PARAM_CAPTURE`, and
// `append_nested_matched_path` reads it back as the prefix for a later match.
#[derive(Clone, Debug)]
struct MatchedNestedPath(Arc<str>);
127 
set_matched_path_for_request( id: RouteId, route_id_to_path: &HashMap<RouteId, Arc<str>>, extensions: &mut http::Extensions, )128 pub(crate) fn set_matched_path_for_request(
129     id: RouteId,
130     route_id_to_path: &HashMap<RouteId, Arc<str>>,
131     extensions: &mut http::Extensions,
132 ) {
133     let matched_path = if let Some(matched_path) = route_id_to_path.get(&id) {
134         matched_path
135     } else {
136         #[cfg(debug_assertions)]
137         panic!("should always have a matched path for a route id");
138         #[cfg(not(debug_assertions))]
139         return;
140     };
141 
142     let matched_path = append_nested_matched_path(matched_path, extensions);
143 
144     if matched_path.ends_with(NEST_TAIL_PARAM_CAPTURE) {
145         extensions.insert(MatchedNestedPath(matched_path));
146         debug_assert!(extensions.remove::<MatchedPath>().is_none());
147     } else {
148         extensions.insert(MatchedPath(matched_path));
149         extensions.remove::<MatchedNestedPath>();
150     }
151 }
152 
153 // a previous `MatchedPath` might exist if we're inside a nested Router
append_nested_matched_path(matched_path: &Arc<str>, extensions: &http::Extensions) -> Arc<str>154 fn append_nested_matched_path(matched_path: &Arc<str>, extensions: &http::Extensions) -> Arc<str> {
155     if let Some(previous) = extensions
156         .get::<MatchedPath>()
157         .map(|matched_path| matched_path.as_str())
158         .or_else(|| Some(&extensions.get::<MatchedNestedPath>()?.0))
159     {
160         let previous = previous
161             .strip_suffix(NEST_TAIL_PARAM_CAPTURE)
162             .unwrap_or(previous);
163 
164         let matched_path = format!("{previous}{matched_path}");
165         matched_path.into()
166     } else {
167         Arc::clone(matched_path)
168     }
169 }
170 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        body::Body,
        handler::HandlerWithoutStateExt,
        middleware::map_request,
        routing::{any, get},
        test_helpers::*,
        Router,
    };
    use http::{Request, StatusCode};

    /// A handler on a plain route sees the route pattern it was registered with.
    #[crate::test]
    async fn extracting_on_handler() {
        let app = Router::new().route(
            "/:a",
            get(|path: MatchedPath| async move { path.as_str().to_owned() }),
        );

        let client = TestClient::new(app);

        let res = client.get("/foo").send().await;
        assert_eq!(res.text().await, "/:a");
    }

    /// A handler inside a nested router sees the nest prefix joined with its
    /// own route pattern.
    #[crate::test]
    async fn extracting_on_handler_in_nested_router() {
        let app = Router::new().nest(
            "/:a",
            Router::new().route(
                "/:b",
                get(|path: MatchedPath| async move { path.as_str().to_owned() }),
            ),
        );

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.text().await, "/:a/:b");
    }

    /// Prefixes accumulate across several levels of nesting.
    #[crate::test]
    async fn extracting_on_handler_in_deeply_nested_router() {
        let app = Router::new().nest(
            "/:a",
            Router::new().nest(
                "/:b",
                Router::new().route(
                    "/:c",
                    get(|path: MatchedPath| async move { path.as_str().to_owned() }),
                ),
            ),
        );

        let client = TestClient::new(app);

        let res = client.get("/foo/bar/baz").send().await;
        assert_eq!(res.text().await, "/:a/:b/:c");
    }

    /// Middleware wrapping a `nest_service` call cannot extract `MatchedPath`.
    #[crate::test]
    async fn cannot_extract_nested_matched_path_in_middleware() {
        async fn extract_matched_path<B>(
            matched_path: Option<MatchedPath>,
            req: Request<B>,
        ) -> Request<B> {
            assert!(matched_path.is_none());
            req
        }

        let app = Router::new()
            .nest_service("/:a", Router::new().route("/:b", get(|| async move {})))
            .layer(map_request(extract_matched_path));

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Middleware wrapping a `nest` call extracts the full nested pattern.
    #[crate::test]
    async fn can_extract_nested_matched_path_in_middleware_using_nest() {
        async fn extract_matched_path<B>(
            matched_path: Option<MatchedPath>,
            req: Request<B>,
        ) -> Request<B> {
            assert_eq!(matched_path.unwrap().as_str(), "/:a/:b");
            req
        }

        let app = Router::new()
            .nest("/:a", Router::new().route("/:b", get(|| async move {})))
            .layer(map_request(extract_matched_path));

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Same as the `nest_service` middleware test, but reading the request
    /// extensions directly instead of using the extractor.
    #[crate::test]
    async fn cannot_extract_nested_matched_path_in_middleware_via_extension() {
        async fn assert_no_matched_path<B>(req: Request<B>) -> Request<B> {
            assert!(req.extensions().get::<MatchedPath>().is_none());
            req
        }

        let app = Router::new()
            .nest_service("/:a", Router::new().route("/:b", get(|| async move {})))
            .layer(map_request(assert_no_matched_path));

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Same as the `nest` middleware test, but reading the request extensions
    /// directly instead of using the extractor.
    // Uses `#[crate::test]` like every other test in this module.
    #[crate::test]
    async fn can_extract_nested_matched_path_in_middleware_via_extension_using_nest() {
        async fn assert_matched_path<B>(req: Request<B>) -> Request<B> {
            assert!(req.extensions().get::<MatchedPath>().is_some());
            req
        }

        let app = Router::new()
            .nest("/:a", Router::new().route("/:b", get(|| async move {})))
            .layer(map_request(assert_matched_path));

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Middleware applied on the nested router itself extracts the full
    /// pattern, including the outer nest prefix.
    #[crate::test]
    async fn can_extract_nested_matched_path_in_middleware_on_nested_router() {
        async fn extract_matched_path<B>(matched_path: MatchedPath, req: Request<B>) -> Request<B> {
            assert_eq!(matched_path.as_str(), "/:a/:b");
            req
        }

        let app = Router::new().nest(
            "/:a",
            Router::new()
                .route("/:b", get(|| async move {}))
                .layer(map_request(extract_matched_path)),
        );

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Same as above, but reading the request extensions directly.
    #[crate::test]
    async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() {
        async fn extract_matched_path<B>(req: Request<B>) -> Request<B> {
            let matched_path = req.extensions().get::<MatchedPath>().unwrap();
            assert_eq!(matched_path.as_str(), "/:a/:b");
            req
        }

        let app = Router::new().nest(
            "/:a",
            Router::new()
                .route("/:b", get(|| async move {}))
                .layer(map_request(extract_matched_path)),
        );

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// A handler mounted directly with `nest_service` gets no `MatchedPath`.
    #[crate::test]
    async fn extracting_on_nested_handler() {
        async fn handler(path: Option<MatchedPath>) {
            assert!(path.is_none());
        }

        let app = Router::new().nest_service("/:a", handler.into_service());

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    // https://github.com/tokio-rs/axum/issues/1579
    /// Calling a fresh router from inside a wildcard route handler must not
    /// trip the debug assertions in `set_matched_path_for_request`.
    #[crate::test]
    async fn doesnt_panic_if_router_called_from_wildcard_route() {
        use tower::ServiceExt;

        let app = Router::new().route(
            "/*path",
            any(|req: Request<Body>| {
                Router::new()
                    .nest("/", Router::new().route("/foo", get(|| async {})))
                    .oneshot(req)
            }),
        );

        let client = TestClient::new(app);

        let res = client.get("/foo").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// A fallback handler has no `MatchedPath`, neither via the extractor nor
    /// in the request extensions.
    #[crate::test]
    async fn cant_extract_in_fallback() {
        async fn handler(path: Option<MatchedPath>, req: Request<Body>) {
            assert!(path.is_none());
            assert!(req.extensions().get::<MatchedPath>().is_none());
        }

        let app = Router::new().fallback(handler);

        let client = TestClient::new(app);

        let res = client.get("/foo/bar").send().await;
        assert_eq!(res.status(), StatusCode::OK);
    }
}
396