1 use crate::body::HttpBody;
2 use axum_core::response::IntoResponse;
3 use http::Request;
4 use matchit::MatchError;
5 use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc};
6 use tower_layer::Layer;
7 use tower_service::Service;
8
9 use super::{
10 future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
11 MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
12 };
13
14 pub(super) struct PathRouter<S, B, const IS_FALLBACK: bool> {
15 routes: HashMap<RouteId, Endpoint<S, B>>,
16 node: Arc<Node>,
17 prev_route_id: RouteId,
18 }
19
20 impl<S, B> PathRouter<S, B, true>
21 where
22 B: HttpBody + Send + 'static,
23 S: Clone + Send + Sync + 'static,
24 {
new_fallback() -> Self25 pub(super) fn new_fallback() -> Self {
26 let mut this = Self::default();
27 this.set_fallback(Endpoint::Route(Route::new(NotFound)));
28 this
29 }
30
set_fallback(&mut self, endpoint: Endpoint<S, B>)31 pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S, B>) {
32 self.replace_endpoint("/", endpoint.clone());
33 self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
34 }
35 }
36
37 impl<S, B, const IS_FALLBACK: bool> PathRouter<S, B, IS_FALLBACK>
38 where
39 B: HttpBody + Send + 'static,
40 S: Clone + Send + Sync + 'static,
41 {
route( &mut self, path: &str, method_router: MethodRouter<S, B>, ) -> Result<(), Cow<'static, str>>42 pub(super) fn route(
43 &mut self,
44 path: &str,
45 method_router: MethodRouter<S, B>,
46 ) -> Result<(), Cow<'static, str>> {
47 fn validate_path(path: &str) -> Result<(), &'static str> {
48 if path.is_empty() {
49 return Err("Paths must start with a `/`. Use \"/\" for root routes");
50 } else if !path.starts_with('/') {
51 return Err("Paths must start with a `/`");
52 }
53
54 Ok(())
55 }
56
57 validate_path(path)?;
58
59 let id = self.next_route_id();
60
61 let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
62 .node
63 .path_to_route_id
64 .get(path)
65 .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc)))
66 {
67 // if we're adding a new `MethodRouter` to a route that already has one just
68 // merge them. This makes `.route("/", get(_)).route("/", post(_))` work
69 let service = Endpoint::MethodRouter(
70 prev_method_router
71 .clone()
72 .merge_for_path(Some(path), method_router),
73 );
74 self.routes.insert(route_id, service);
75 return Ok(());
76 } else {
77 Endpoint::MethodRouter(method_router)
78 };
79
80 self.set_node(path, id)?;
81 self.routes.insert(id, endpoint);
82
83 Ok(())
84 }
85
route_service<T>( &mut self, path: &str, service: T, ) -> Result<(), Cow<'static, str>> where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,86 pub(super) fn route_service<T>(
87 &mut self,
88 path: &str,
89 service: T,
90 ) -> Result<(), Cow<'static, str>>
91 where
92 T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
93 T::Response: IntoResponse,
94 T::Future: Send + 'static,
95 {
96 self.route_endpoint(path, Endpoint::Route(Route::new(service)))
97 }
98
route_endpoint( &mut self, path: &str, endpoint: Endpoint<S, B>, ) -> Result<(), Cow<'static, str>>99 pub(super) fn route_endpoint(
100 &mut self,
101 path: &str,
102 endpoint: Endpoint<S, B>,
103 ) -> Result<(), Cow<'static, str>> {
104 if path.is_empty() {
105 return Err("Paths must start with a `/`. Use \"/\" for root routes".into());
106 } else if !path.starts_with('/') {
107 return Err("Paths must start with a `/`".into());
108 }
109
110 let id = self.next_route_id();
111 self.set_node(path, id)?;
112 self.routes.insert(id, endpoint);
113
114 Ok(())
115 }
116
set_node(&mut self, path: &str, id: RouteId) -> Result<(), String>117 fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> {
118 let mut node =
119 Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone());
120 if let Err(err) = node.insert(path, id) {
121 return Err(format!("Invalid route {path:?}: {err}"));
122 }
123 self.node = Arc::new(node);
124 Ok(())
125 }
126
merge( &mut self, other: PathRouter<S, B, IS_FALLBACK>, ) -> Result<(), Cow<'static, str>>127 pub(super) fn merge(
128 &mut self,
129 other: PathRouter<S, B, IS_FALLBACK>,
130 ) -> Result<(), Cow<'static, str>> {
131 let PathRouter {
132 routes,
133 node,
134 prev_route_id: _,
135 } = other;
136
137 for (id, route) in routes {
138 let path = node
139 .route_id_to_path
140 .get(&id)
141 .expect("no path for route id. This is a bug in axum. Please file an issue");
142
143 if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
144 // when merging two routers it doesn't matter if you do `a.merge(b)` or
145 // `b.merge(a)`. This must also be true for fallbacks.
146 //
147 // However all fallback routers will have routes for `/` and `/*` so when merging
148 // we have to ignore the top level fallbacks on one side otherwise we get
149 // conflicts.
150 //
151 // `Router::merge` makes sure that when merging fallbacks `other` always has the
152 // fallback we want to keep. It panics if both routers have a custom fallback. Thus
153 // it is always okay to ignore one fallback and `Router::merge` also makes sure the
154 // one we can ignore is that of `self`.
155 self.replace_endpoint(path, route);
156 } else {
157 match route {
158 Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
159 Endpoint::Route(route) => self.route_service(path, route)?,
160 }
161 }
162 }
163
164 Ok(())
165 }
166
nest( &mut self, path: &str, router: PathRouter<S, B, IS_FALLBACK>, ) -> Result<(), Cow<'static, str>>167 pub(super) fn nest(
168 &mut self,
169 path: &str,
170 router: PathRouter<S, B, IS_FALLBACK>,
171 ) -> Result<(), Cow<'static, str>> {
172 let prefix = validate_nest_path(path);
173
174 let PathRouter {
175 routes,
176 node,
177 prev_route_id: _,
178 } = router;
179
180 for (id, endpoint) in routes {
181 let inner_path = node
182 .route_id_to_path
183 .get(&id)
184 .expect("no path for route id. This is a bug in axum. Please file an issue");
185
186 let path = path_for_nested_route(prefix, inner_path);
187
188 match endpoint.layer(StripPrefix::layer(prefix)) {
189 Endpoint::MethodRouter(method_router) => {
190 self.route(&path, method_router)?;
191 }
192 Endpoint::Route(route) => {
193 self.route_endpoint(&path, Endpoint::Route(route))?;
194 }
195 }
196 }
197
198 Ok(())
199 }
200
nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>> where T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static,201 pub(super) fn nest_service<T>(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>>
202 where
203 T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
204 T::Response: IntoResponse,
205 T::Future: Send + 'static,
206 {
207 let path = validate_nest_path(path);
208 let prefix = path;
209
210 let path = if path.ends_with('/') {
211 format!("{path}*{NEST_TAIL_PARAM}")
212 } else {
213 format!("{path}/*{NEST_TAIL_PARAM}")
214 };
215
216 let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix)));
217
218 self.route_endpoint(&path, endpoint.clone())?;
219
220 // `/*rest` is not matched by `/` so we need to also register a router at the
221 // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself
222 // wouldn't match, which it should
223 self.route_endpoint(prefix, endpoint.clone())?;
224 if !prefix.ends_with('/') {
225 // same goes for `/foo/`, that should also match
226 self.route_endpoint(&format!("{prefix}/"), endpoint)?;
227 }
228
229 Ok(())
230 }
231
layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody, IS_FALLBACK> where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static, <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static, NewReqBody: HttpBody + 'static,232 pub(super) fn layer<L, NewReqBody>(self, layer: L) -> PathRouter<S, NewReqBody, IS_FALLBACK>
233 where
234 L: Layer<Route<B>> + Clone + Send + 'static,
235 L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
236 <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
237 <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
238 <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
239 NewReqBody: HttpBody + 'static,
240 {
241 let routes = self
242 .routes
243 .into_iter()
244 .map(|(id, endpoint)| {
245 let route = endpoint.layer(layer.clone());
246 (id, route)
247 })
248 .collect();
249
250 PathRouter {
251 routes,
252 node: self.node,
253 prev_route_id: self.prev_route_id,
254 }
255 }
256
257 #[track_caller]
route_layer<L>(self, layer: L) -> Self where L: Layer<Route<B>> + Clone + Send + 'static, L::Service: Service<Request<B>> + Clone + Send + 'static, <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static, <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static, <L::Service as Service<Request<B>>>::Future: Send + 'static,258 pub(super) fn route_layer<L>(self, layer: L) -> Self
259 where
260 L: Layer<Route<B>> + Clone + Send + 'static,
261 L::Service: Service<Request<B>> + Clone + Send + 'static,
262 <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
263 <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
264 <L::Service as Service<Request<B>>>::Future: Send + 'static,
265 {
266 if self.routes.is_empty() {
267 panic!(
268 "Adding a route_layer before any routes is a no-op. \
269 Add the routes you want the layer to apply to first."
270 );
271 }
272
273 let routes = self
274 .routes
275 .into_iter()
276 .map(|(id, endpoint)| {
277 let route = endpoint.layer(layer.clone());
278 (id, route)
279 })
280 .collect();
281
282 PathRouter {
283 routes,
284 node: self.node,
285 prev_route_id: self.prev_route_id,
286 }
287 }
288
with_state<S2>(self, state: S) -> PathRouter<S2, B, IS_FALLBACK>289 pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, B, IS_FALLBACK> {
290 let routes = self
291 .routes
292 .into_iter()
293 .map(|(id, endpoint)| {
294 let endpoint: Endpoint<S2, B> = match endpoint {
295 Endpoint::MethodRouter(method_router) => {
296 Endpoint::MethodRouter(method_router.with_state(state.clone()))
297 }
298 Endpoint::Route(route) => Endpoint::Route(route),
299 };
300 (id, endpoint)
301 })
302 .collect();
303
304 PathRouter {
305 routes,
306 node: self.node,
307 prev_route_id: self.prev_route_id,
308 }
309 }
310
call_with_state( &mut self, mut req: Request<B>, state: S, ) -> Result<RouteFuture<B, Infallible>, (Request<B>, S)>311 pub(super) fn call_with_state(
312 &mut self,
313 mut req: Request<B>,
314 state: S,
315 ) -> Result<RouteFuture<B, Infallible>, (Request<B>, S)> {
316 #[cfg(feature = "original-uri")]
317 {
318 use crate::extract::OriginalUri;
319
320 if req.extensions().get::<OriginalUri>().is_none() {
321 let original_uri = OriginalUri(req.uri().clone());
322 req.extensions_mut().insert(original_uri);
323 }
324 }
325
326 let path = req.uri().path().to_owned();
327
328 match self.node.at(&path) {
329 Ok(match_) => {
330 let id = *match_.value;
331
332 if !IS_FALLBACK {
333 #[cfg(feature = "matched-path")]
334 crate::extract::matched_path::set_matched_path_for_request(
335 id,
336 &self.node.route_id_to_path,
337 req.extensions_mut(),
338 );
339 }
340
341 url_params::insert_url_params(req.extensions_mut(), match_.params);
342
343 let endpont = self
344 .routes
345 .get_mut(&id)
346 .expect("no route for id. This is a bug in axum. Please file an issue");
347
348 match endpont {
349 Endpoint::MethodRouter(method_router) => {
350 Ok(method_router.call_with_state(req, state))
351 }
352 Endpoint::Route(route) => Ok(route.clone().call(req)),
353 }
354 }
355 // explicitly handle all variants in case matchit adds
356 // new ones we need to handle differently
357 Err(
358 MatchError::NotFound
359 | MatchError::ExtraTrailingSlash
360 | MatchError::MissingTrailingSlash,
361 ) => Err((req, state)),
362 }
363 }
364
replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S, B>)365 pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S, B>) {
366 match self.node.at(path) {
367 Ok(match_) => {
368 let id = *match_.value;
369 self.routes.insert(id, endpoint);
370 }
371 Err(_) => self
372 .route_endpoint(path, endpoint)
373 .expect("path wasn't matched so endpoint shouldn't exist"),
374 }
375 }
376
next_route_id(&mut self) -> RouteId377 fn next_route_id(&mut self) -> RouteId {
378 let next_id = self
379 .prev_route_id
380 .0
381 .checked_add(1)
382 .expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
383 self.prev_route_id = RouteId(next_id);
384 self.prev_route_id
385 }
386 }
387
388 impl<B, S, const IS_FALLBACK: bool> Default for PathRouter<S, B, IS_FALLBACK> {
default() -> Self389 fn default() -> Self {
390 Self {
391 routes: Default::default(),
392 node: Default::default(),
393 prev_route_id: RouteId(0),
394 }
395 }
396 }
397
398 impl<S, B, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, B, IS_FALLBACK> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result399 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400 f.debug_struct("PathRouter")
401 .field("routes", &self.routes)
402 .field("node", &self.node)
403 .finish()
404 }
405 }
406
407 impl<S, B, const IS_FALLBACK: bool> Clone for PathRouter<S, B, IS_FALLBACK> {
clone(&self) -> Self408 fn clone(&self) -> Self {
409 Self {
410 routes: self.routes.clone(),
411 node: self.node.clone(),
412 prev_route_id: self.prev_route_id,
413 }
414 }
415 }
416
417 /// Wrapper around `matchit::Router` that supports merging two `Router`s.
418 #[derive(Clone, Default)]
419 struct Node {
420 inner: matchit::Router<RouteId>,
421 route_id_to_path: HashMap<RouteId, Arc<str>>,
422 path_to_route_id: HashMap<Arc<str>, RouteId>,
423 }
424
425 impl Node {
insert( &mut self, path: impl Into<String>, val: RouteId, ) -> Result<(), matchit::InsertError>426 fn insert(
427 &mut self,
428 path: impl Into<String>,
429 val: RouteId,
430 ) -> Result<(), matchit::InsertError> {
431 let path = path.into();
432
433 self.inner.insert(&path, val)?;
434
435 let shared_path: Arc<str> = path.into();
436 self.route_id_to_path.insert(val, shared_path.clone());
437 self.path_to_route_id.insert(shared_path, val);
438
439 Ok(())
440 }
441
at<'n, 'p>( &'n self, path: &'p str, ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError>442 fn at<'n, 'p>(
443 &'n self,
444 path: &'p str,
445 ) -> Result<matchit::Match<'n, 'p, &'n RouteId>, MatchError> {
446 self.inner.at(path)
447 }
448 }
449
450 impl fmt::Debug for Node {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result451 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452 f.debug_struct("Node")
453 .field("paths", &self.route_id_to_path)
454 .finish()
455 }
456 }
457
/// Normalizes and validates the prefix passed to `nest` / `nest_service`.
///
/// An empty prefix is treated as `/`; any other prefix is returned unchanged.
///
/// # Panics
///
/// Panics if the prefix contains a wildcard (`*`).
#[track_caller]
fn validate_nest_path(path: &str) -> &str {
    match path {
        // nesting at `""` and `"/"` should mean the same thing
        "" => "/",
        p if p.contains('*') => {
            panic!("Invalid route: nested routes cannot contain wildcards (*)")
        }
        p => p,
    }
}
471
/// Joins a nest `prefix` and an inner route `path` into a single route path.
///
/// Both arguments are expected to start with `/`, which is checked in debug builds.
/// A trailing slash on the prefix absorbs the inner path's leading slashes. An inner
/// path of `/` maps to the bare prefix, which is then returned borrowed.
pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> {
    debug_assert!(prefix.starts_with('/'));
    debug_assert!(path.starts_with('/'));

    match (prefix.ends_with('/'), path) {
        (true, inner) => Cow::Owned(format!("{prefix}{}", inner.trim_start_matches('/'))),
        (false, "/") => Cow::Borrowed(prefix),
        (false, inner) => Cow::Owned(format!("{prefix}{inner}")),
    }
}
484