1 use crate::body::HttpBody;
2 use axum_core::response::IntoResponse;
3 use http::Request;
4 use matchit::MatchError;
5 use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
6 use tower_layer::Layer;
7 use tower_service::Service;
8 
9 use super::{
10     future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
11     MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
12 };
13 
14 pub(super) struct PathRouter<S, B, const IS_FALLBACK: bool> {
15     routes: HashMap<RouteId, Endpoint<S, B>>,
16     node: Arc<Node>,
17     prev_route_id: RouteId,
18 }
19 
20 impl<S, B> PathRouter<S, B, true>
21 where
22     B: HttpBody + Send + 'static,
23     S: Clone + Send + Sync + 'static,
24 {
new_fallback() -> Self25     pub(super) fn new_fallback() -> Self {
26         let mut this = Self::default();
27         this.set_fallback(Endpoint::Route(Route::new(NotFound)));
28         this
29     }
30 
set_fallback(&mut self, endpoint: Endpoint<S, B>)31     pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S, B>) {
32         self.replace_endpoint("/", endpoint.clone());
33         self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
34     }
35 }
36 
37 impl<S, B, const IS_FALLBACK: bool> PathRouter<S, B, IS_FALLBACK>
38 where
39     B: HttpBody + Send + 'static,
40     S: Clone + Send + Sync + 'static,
41 {
route( &mut self, path: &str, method_router: MethodRouter<S, B>, ) -> Result<(), Cow<'static, str>>42     pub(super) fn route(
43         &mut self,
44         path: &str,
45         method_router: MethodRouter<S, B>,
46     ) -> Result<(), Cow<'static, str>> {
47         fn validate_path(path: &str) -> Result<(), &'static str> {
48             if path.is_empty() {
49                 return Err("Paths must start with a `/`. Use \"/\" for root routes");
50             } else if !path.starts_with('/') {
51                 return Err("Paths must start with a `/`");
52             }
53 
54             Ok(())
55         }
56 
57         validate_path(path)?;
58 
59         let id = self.next_route_id();
60 
61         let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
62             .node
63             .path_to_route_id
64             .get(path)
65             .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
66         {
67             // if we're adding a new `MethodRouter` to a route that already has one just
68             // merge them. This makes `.route("/", get(_)).route("/", post(_))` work
69             let service = Endpoint::MethodRouter(
70                 prev_method_router
71                     .clone()
72                     .merge_for_path(Some(path), method_router),
73             );
74             self.routes.insert(route_id, service);
75             return Ok(());
76         } else {
77             Endpoint::MethodRouter(method_router)
78         };
79 
80         self.set_node(path, id)?;
81         self.routes.insert(id, endpoint);
82 
83         Ok(())
84     }
85 
route_service<T>( &mut self, path: &str, service: T, ) -> Result<(), Cow<'static, str>> where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,86     pub(super) fn route_service<T>(
87         &mut self,
88         path: &str,
89         service: T,
90     ) -> Result<(), Cow<'static, str>>
91     where
92         T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
93         T::Response: IntoResponse,
94         T::Future: Send + 'static,
95     {
96         self.route_endpoint(path, Endpoint::Route(Route::new(service)))
97     }
98 
route_endpoint( &mut self, path: &str, endpoint: Endpoint<S, B>, ) -> Result<(), Cow<'static, str>>99     pub(super) fn route_endpoint(
100         &mut self,
101         path: &str,
102         endpoint: Endpoint<S, B>,
103     ) -> Result<(), Cow<'static, str>> {
104         if path.is_empty() {
105             return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
106         } else if !path.starts_with('/') {
107             return Err("Paths must start with a `/`".into());
108         }
109 
110         let id = self.next_route_id();
111         self.set_node(path, id)?;
112         self.routes.insert(id, endpoint);
113 
114         Ok(())
115     }
116 
set_node(&mut self, path: &str, id: RouteId) -> Result<(), String>117     fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
118         let mut node =
119             Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
120         if let Err(err) = node.insert(path, id) {
121             return Err(format!("Invalid route {path:?}: {err}"));
122         }
123         self.node = Arc::new(node);
124         Ok(())
125     }
126 
merge( &mut self, other: PathRouter<S, B, IS_FALLBACK>, ) -> Result<(), Cow<'static, str>>127     pub(super) fn merge(
128         &mut self,
129         other: PathRouter<S, B, IS_FALLBACK>,
130     ) -> Result<(), Cow<'static, str>> {
131         let PathRouter {
132             routes,
133             node,
134             prev_route_id: _,
135         } = other;
136 
137         for (id, route) in routes {
138             let path = node
139                 .route_id_to_path
140                 .get(&id)
141                 .expect("no path for route id. This is a bug in axum. Please file an issue");
142 
143             if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
144                 // when merging two routers it doesn't matter if you do `a.merge(b)` or
145                 // `b.merge(a)`. This must also be true for fallbacks.
146                 //
147                 // However all fallback routers will have routes for `/` and `/*` so when merging
148                 // we have to ignore the top level fallbacks on one side otherwise we get
149                 // conflicts.
150                 //
151                 // `Router::merge` makes sure that when merging fallbacks `other` always has the
152                 // fallback we want to keep. It panics if both routers have a custom fallback. Thus
153                 // it is always okay to ignore one fallback and `Router::merge` also makes sure the
154                 // one we can ignore is that of `self`.
155                 self.replace_endpoint(path, route);
156             } else {
157                 match route {
158                     Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
159                     Endpoint::Route(route) => self.route_service(path, route)?,
160                 }
161             }
162         }
163 
164         Ok(())
165     }
166 
nest( &mut self, path: &str, router: PathRouter<S, B, IS_FALLBACK>, ) -> Result<(), Cow<'static, str>>167     pub(super) fn nest(
168         &mut self,
169         path: &str,
170         router: PathRouter<S, B, IS_FALLBACK>,
171     ) -> Result<(), Cow<'static, str>> {
172         let prefix = validate_nest_path(path);
173 
174         let PathRouter {
175             routes,
176             node,
177             prev_route_id: _,
178         } = router;
179 
180         for (id, endpoint) in routes {
181             let inner_path = node
182                 .route_id_to_path
183                 .get(&id)
184                 .expect("no path for route id. This is a bug in axum. Please file an issue");
185 
186             let path = path_for_nested_route(prefix, inner_path);
187 
188             match endpoint.layer(StripPrefix::layer(prefix)) {
189                 Endpoint::MethodRouter(method_router) => {
190                     self.route(&path, method_router)?;
191                 }
192                 Endpoint::Route(route) => {
193                     self.route_endpoint(&path, Endpoint::Route(route))?;
194                 }
195             }
196         }
197 
198         Ok(())
199     }
200 
nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>> where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,201     pub(super) fn nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>>
202     where
203         T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
204         T::Response: IntoResponse,
205         T::Future: Send + 'static,
206     {
207         let path = validate_nest_path(path);
208         let prefix = path;
209 
210         let path = if path.ends_with('/') {
211             format!("{path}*{NEST_TAIL_PARAM}")
212         } else {
213             format!("{path}/*{NEST_TAIL_PARAM}")
214         };
215 
216         let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)));
217 
218         self.route_endpoint(&path, endpoint.clone())?;
219 
220         // `/*rest` is not matched by `/` so we need to also register a router at the
221         // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
222         // wouldn't match, which it should
223         self.route_endpoint(prefix, endpoint.clone())?;
224         if !prefix.ends_with('/') {
225             // same goes for `/foo/`, that should also match
226             self.route_endpoint(&format!("{prefix}/"), endpoint)?;
227         }
228 
229         Ok(())
230     }
231 
layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody, IS_FALLBACK> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: HttpBody + 'static,232     pub(super) fn layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody, IS_FALLBACK>
233     where
234         L: Layer<Route<B>> + Clone + Send + 'static,
235         L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
236         <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
237         <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
238         <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
239         NewReqBody: HttpBody + 'static,
240     {
241         let routes = self
242             .routes
243             .into_iter()
244             .map(|(id, endpoint)| {
245                 let route = endpoint.layer(layer.clone());
246                 (id, route)
247             })
248             .collect();
249 
250         PathRouter {
251             routes,
252             node: self.node,
253             prev_route_id: self.prev_route_id,
254         }
255     }
256 
257     #[track_caller]
route_layer<L>(self, layer: L) -> Self where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<B>> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static,258     pub(super) fn route_layer<L>(self, layer: L) -> Self
259     where
260         L: Layer<Route<B>> + Clone + Send + 'static,
261         L::Service: Service<Request<B>> + Clone + Send + 'static,
262         <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
263         <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
264         <L::Service as Service<Request<B>>>::Future: Send + 'static,
265     {
266         if self.routes.is_empty() {
267             panic!(
268                 "Adding a route_layer before any routes is a no-op. \
269                  Add the routes you want the layer to apply to first."
270             );
271         }
272 
273         let routes = self
274             .routes
275             .into_iter()
276             .map(|(id, endpoint)| {
277                 let route = endpoint.layer(layer.clone());
278                 (id, route)
279             })
280             .collect();
281 
282         PathRouter {
283             routes,
284             node: self.node,
285             prev_route_id: self.prev_route_id,
286         }
287     }
288 
with_state<S2>(self, state: S) -> PathRouter<S2, B, IS_FALLBACK>289     pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, B, IS_FALLBACK> {
290         let routes = self
291             .routes
292             .into_iter()
293             .map(|(id, endpoint)| {
294                 let endpoint: Endpoint<S2, B> = match endpoint {
295                     Endpoint::MethodRouter(method_router) => {
296                         Endpoint::MethodRouter(method_router.with_state(state.clone()))
297                     }
298                     Endpoint::Route(route) => Endpoint::Route(route),
299                 };
300                 (id, endpoint)
301             })
302             .collect();
303 
304         PathRouter {
305             routes,
306             node: self.node,
307             prev_route_id: self.prev_route_id,
308         }
309     }
310 
call_with_state( &mut self, mut req: Request<B>, state: S, ) -> Result<RouteFuture<B, Infallible>, (Request<B>, S)>311     pub(super) fn call_with_state(
312         &mut self,
313         mut req: Request<B>,
314         state: S,
315     ) -> Result<RouteFuture<B, Infallible>, (Request<B>, S)> {
316         #[cfg(feature = "original-uri")]
317         {
318             use crate::extract::OriginalUri;
319 
320             if req.extensions().get::<OriginalUri>().is_none() {
321                 let original_uri = OriginalUri(req.uri().clone());
322                 req.extensions_mut().insert(original_uri);
323             }
324         }
325 
326         let path = req.uri().path().to_owned();
327 
328         match self.node.at(&path) {
329             Ok(match_) => {
330                 let id = *match_.value;
331 
332                 if !IS_FALLBACK {
333                     #[cfg(feature = "matched-path")]
334                     crate::extract::matched_path::set_matched_path_for_request(
335                         id,
336                         &self.node.route_id_to_path,
337                         req.extensions_mut(),
338                     );
339                 }
340 
341                 url_params::insert_url_params(req.extensions_mut(), match_.params);
342 
343                 let endpont = self
344                     .routes
345                     .get_mut(&id)
346                     .expect("no route for id. This is a bug in axum. Please file an issue");
347 
348                 match endpont {
349                     Endpoint::MethodRouter(method_router) => {
350                         Ok(method_router.call_with_state(req, state))
351                     }
352                     Endpoint::Route(route) => Ok(route.clone().call(req)),
353                 }
354             }
355             // explicitly handle all variants in case matchit adds
356             // new ones we need to handle differently
357             Err(
358                 MatchError::NotFound
359                 | MatchError::ExtraTrailingSlash
360                 | MatchError::MissingTrailingSlash,
361             ) => Err((req, state)),
362         }
363     }
364 
replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S, B>)365     pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S, B>) {
366         match self.node.at(path) {
367             Ok(match_) => {
368                 let id = *match_.value;
369                 self.routes.insert(id, endpoint);
370             }
371             Err(_) => self
372                 .route_endpoint(path, endpoint)
373                 .expect("path wasn't matched so endpoint shouldn't exist"),
374         }
375     }
376 
next_route_id(&mut self) -> RouteId377     fn next_route_id(&mut self) -> RouteId {
378         let next_id = self
379             .prev_route_id
380             .0
381             .checked_add(1)
382             .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
383         self.prev_route_id = RouteId(next_id);
384         self.prev_route_id
385     }
386 }
387 
388 impl<B, S, const IS_FALLBACK: bool> Default for PathRouter<S, B, IS_FALLBACK> {
default() -> Self389     fn default() -> Self {
390         Self {
391             routes: Default::default(),
392             node: Default::default(),
393             prev_route_id: RouteId(0),
394         }
395     }
396 }
397 
398 impl<S, B, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, B, IS_FALLBACK> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result399     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400         f.debug_struct("PathRouter")
401             .field("routes", &self.routes)
402             .field("node", &self.node)
403             .finish()
404     }
405 }
406 
407 impl<S, B, const IS_FALLBACK: bool> Clone for PathRouter<S, B, IS_FALLBACK> {
clone(&self) -> Self408     fn clone(&self) -> Self {
409         Self {
410             routes: self.routes.clone(),
411             node: self.node.clone(),
412             prev_route_id: self.prev_route_id,
413         }
414     }
415 }
416 
417 /// Wrapper around `matchit::Router` that supports merging two `Router`s.
418 #[derive(Clone, Default)]
419 struct Node {
420     inner: matchit::Router<RouteId>,
421     route_id_to_path: HashMap<RouteId, Arc<str>>,
422     path_to_route_id: HashMap<Arc<str>, RouteId>,
423 }
424 
425 impl Node {
insert( &mut self, path: impl Into<String>, val: RouteId, ) -> Result<(), matchit::InsertError>426     fn insert(
427         &mut self,
428         path: impl Into<String>,
429         val: RouteId,
430     ) -> Result<(), matchit::InsertError> {
431         let path = path.into();
432 
433         self.inner.insert(&path, val)?;
434 
435         let shared_path: Arc<str> = path.into();
436         self.route_id_to_path.insert(val, shared_path.clone());
437         self.path_to_route_id.insert(shared_path, val);
438 
439         Ok(())
440     }
441 
at<'n, 'p>( &'n self, path: &'p str, ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError>442     fn at<'n, 'p>(
443         &'n self,
444         path: &'p str,
445     ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
446         self.inner.at(path)
447     }
448 }
449 
450 impl fmt::Debug for Node {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result451     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452         f.debug_struct("Node")
453             .field("paths", &self.route_id_to_path)
454             .finish()
455     }
456 }
457 
/// Normalizes and checks a prefix passed to `nest` / `nest_service`.
///
/// An empty prefix is treated as `/`; any other prefix is returned as-is.
///
/// # Panics
///
/// Panics if the prefix contains a `*` wildcard.
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
    if path.is_empty() {
        // nesting at `""` and `"/"` should mean the same thing
        "/"
    } else if path.contains('*') {
        panic!("Invalid route: nested routes cannot contain wildcards (*)");
    } else {
        path
    }
}
471 
path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str>472 pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
473     debug_assert!(prefix.starts_with('/'));
474     debug_assert!(path.starts_with('/'));
475 
476     if prefix.ends_with('/') {
477         format!("{prefix}{}", path.trim_start_matches('/')).into()
478     } else if path == "/" {
479         prefix.into()
480     } else {
481         format!("{prefix}{path}").into()
482     }
483 }
484