use crate::body::HttpBody; use axum_core::response::IntoResponse; use http::Request; use matchit::MatchError; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; use tower_layer::Layer; use tower_service::Service; use super::{ future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM, }; pub(super) struct PathRouter { routes: HashMap>, node: Arc, prev_route_id: RouteId, } impl PathRouter where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { pub(super) fn new_fallback() -> Self { let mut this = Self::default(); this.set_fallback(Endpoint::Route(Route::new(NotFound))); this } pub(super) fn set_fallback(&mut self, endpoint: Endpoint) { self.replace_endpoint("/", endpoint.clone()); self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint); } } impl PathRouter where B: HttpBody + Send + 'static, S: Clone + Send + Sync + 'static, { pub(super) fn route( &mut self, path: &str, method_router: MethodRouter, ) -> Result<(), Cow<'static, str>> { fn validate_path(path: &str) -> Result<(), &'static str> { if path.is_empty() { return Err("Paths must start with a `/`. Use \"/\" for root routes"); } else if !path.starts_with('/') { return Err("Paths must start with a `/`"); } Ok(()) } validate_path(path)?; let id = self.next_route_id(); let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self .node .path_to_route_id .get(path) .and_then(|route_id| self.routes.get(route_id).map(|svc| (*route_id, svc))) { // if we're adding a new `MethodRouter` to a route that already has one just // merge them. This makes `.route("/", get(_)).route("/", post(_))` work let service = Endpoint::MethodRouter( prev_method_router .clone() .merge_for_path(Some(path), method_router), ); self.routes.insert(route_id, service); return Ok(()); } else { Endpoint::MethodRouter(method_router) }; self.set_node(path, id)?; self.routes.insert(id, endpoint); Ok(()) } pub(super) fn route_service( &mut self, path: &str, service: T, ) -> Result<(), Cow<'static, str>> where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { self.route_endpoint(path, Endpoint::Route(Route::new(service))) } pub(super) fn route_endpoint( &mut self, path: &str, endpoint: Endpoint, ) -> Result<(), Cow<'static, str>> { if path.is_empty() { return Err("Paths must start with a `/`. Use \"/\" for root routes".into()); } else if !path.starts_with('/') { return Err("Paths must start with a `/`".into()); } let id = self.next_route_id(); self.set_node(path, id)?; self.routes.insert(id, endpoint); Ok(()) } fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> { let mut node = Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); if let Err(err) = node.insert(path, id) { return Err(format!("Invalid route {path:?}: {err}")); } self.node = Arc::new(node); Ok(()) } pub(super) fn merge( &mut self, other: PathRouter, ) -> Result<(), Cow<'static, str>> { let PathRouter { routes, node, prev_route_id: _, } = other; for (id, route) in routes { let path = node .route_id_to_path .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) { // when merging two routers it doesn't matter if you do `a.merge(b)` or // `b.merge(a)`. This must also be true for fallbacks. // // However all fallback routers will have routes for `/` and `/*` so when merging // we have to ignore the top level fallbacks on one side otherwise we get // conflicts. // // `Router::merge` makes sure that when merging fallbacks `other` always has the // fallback we want to keep. It panics if both routers have a custom fallback. Thus // it is always okay to ignore one fallback and `Router::merge` also makes sure the // one we can ignore is that of `self`. self.replace_endpoint(path, route); } else { match route { Endpoint::MethodRouter(method_router) => self.route(path, method_router)?, Endpoint::Route(route) => self.route_service(path, route)?, } } } Ok(()) } pub(super) fn nest( &mut self, path: &str, router: PathRouter, ) -> Result<(), Cow<'static, str>> { let prefix = validate_nest_path(path); let PathRouter { routes, node, prev_route_id: _, } = router; for (id, endpoint) in routes { let inner_path = node .route_id_to_path .get(&id) .expect("no path for route id. This is a bug in axum. Please file an issue"); let path = path_for_nested_route(prefix, inner_path); match endpoint.layer(StripPrefix::layer(prefix)) { Endpoint::MethodRouter(method_router) => { self.route(&path, method_router)?; } Endpoint::Route(route) => { self.route_endpoint(&path, Endpoint::Route(route))?; } } } Ok(()) } pub(super) fn nest_service(&mut self, path: &str, svc: T) -> Result<(), Cow<'static, str>> where T: Service, Error = Infallible> + Clone + Send + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let path = validate_nest_path(path); let prefix = path; let path = if path.ends_with('/') { format!("{path}*{NEST_TAIL_PARAM}") } else { format!("{path}/*{NEST_TAIL_PARAM}") }; let endpoint = Endpoint::Route(Route::new(StripPrefix::new(svc, prefix))); self.route_endpoint(&path, endpoint.clone())?; // `/*rest` is not matched by `/` so we need to also register a router at the // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself // wouldn't match, which it should self.route_endpoint(prefix, endpoint.clone())?; if !prefix.ends_with('/') { // same goes for `/foo/`, that should also match self.route_endpoint(&format!("{prefix}/"), endpoint)?; } Ok(()) } pub(super) fn layer(self, layer: L) -> PathRouter where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, NewReqBody: HttpBody + 'static, { let routes = self .routes .into_iter() .map(|(id, endpoint)| { let route = endpoint.layer(layer.clone()); (id, route) }) .collect(); PathRouter { routes, node: self.node, prev_route_id: self.prev_route_id, } } #[track_caller] pub(super) fn route_layer(self, layer: L) -> Self where L: Layer> + Clone + Send + 'static, L::Service: Service> + Clone + Send + 'static, >>::Response: IntoResponse + 'static, >>::Error: Into + 'static, >>::Future: Send + 'static, { if self.routes.is_empty() { panic!( "Adding a route_layer before any routes is a no-op. \ Add the routes you want the layer to apply to first." ); } let routes = self .routes .into_iter() .map(|(id, endpoint)| { let route = endpoint.layer(layer.clone()); (id, route) }) .collect(); PathRouter { routes, node: self.node, prev_route_id: self.prev_route_id, } } pub(super) fn with_state(self, state: S) -> PathRouter { let routes = self .routes .into_iter() .map(|(id, endpoint)| { let endpoint: Endpoint = match endpoint { Endpoint::MethodRouter(method_router) => { Endpoint::MethodRouter(method_router.with_state(state.clone())) } Endpoint::Route(route) => Endpoint::Route(route), }; (id, endpoint) }) .collect(); PathRouter { routes, node: self.node, prev_route_id: self.prev_route_id, } } pub(super) fn call_with_state( &mut self, mut req: Request, state: S, ) -> Result, (Request, S)> { #[cfg(feature = "original-uri")] { use crate::extract::OriginalUri; if req.extensions().get::().is_none() { let original_uri = OriginalUri(req.uri().clone()); req.extensions_mut().insert(original_uri); } } let path = req.uri().path().to_owned(); match self.node.at(&path) { Ok(match_) => { let id = *match_.value; if !IS_FALLBACK { #[cfg(feature = "matched-path")] crate::extract::matched_path::set_matched_path_for_request( id, &self.node.route_id_to_path, req.extensions_mut(), ); } url_params::insert_url_params(req.extensions_mut(), match_.params); let endpont = self .routes .get_mut(&id) .expect("no route for id. This is a bug in axum. Please file an issue"); match endpont { Endpoint::MethodRouter(method_router) => { Ok(method_router.call_with_state(req, state)) } Endpoint::Route(route) => Ok(route.clone().call(req)), } } // explicitly handle all variants in case matchit adds // new ones we need to handle differently Err( MatchError::NotFound | MatchError::ExtraTrailingSlash | MatchError::MissingTrailingSlash, ) => Err((req, state)), } } pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint) { match self.node.at(path) { Ok(match_) => { let id = *match_.value; self.routes.insert(id, endpoint); } Err(_) => self .route_endpoint(path, endpoint) .expect("path wasn't matched so endpoint shouldn't exist"), } } fn next_route_id(&mut self) -> RouteId { let next_id = self .prev_route_id .0 .checked_add(1) .expect("Over `u32::MAX` routes created. If you need this, please file an issue."); self.prev_route_id = RouteId(next_id); self.prev_route_id } } impl Default for PathRouter { fn default() -> Self { Self { routes: Default::default(), node: Default::default(), prev_route_id: RouteId(0), } } } impl fmt::Debug for PathRouter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PathRouter") .field("routes", &self.routes) .field("node", &self.node) .finish() } } impl Clone for PathRouter { fn clone(&self) -> Self { Self { routes: self.routes.clone(), node: self.node.clone(), prev_route_id: self.prev_route_id, } } } /// Wrapper around `matchit::Router` that supports merging two `Router`s. #[derive(Clone, Default)] struct Node { inner: matchit::Router, route_id_to_path: HashMap>, path_to_route_id: HashMap, RouteId>, } impl Node { fn insert( &mut self, path: impl Into, val: RouteId, ) -> Result<(), matchit::InsertError> { let path = path.into(); self.inner.insert(&path, val)?; let shared_path: Arc = path.into(); self.route_id_to_path.insert(val, shared_path.clone()); self.path_to_route_id.insert(shared_path, val); Ok(()) } fn at<'n, 'p>( &'n self, path: &'p str, ) -> Result, MatchError> { self.inner.at(path) } } impl fmt::Debug for Node { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Node") .field("paths", &self.route_id_to_path) .finish() } } #[track_caller] fn validate_nest_path(path: &str) -> &str { if path.is_empty() { // nesting at `""` and `"/"` should mean the same thing return "/"; } if path.contains('*') { panic!("Invalid route: nested routes cannot contain wildcards (*)"); } path } pub(crate) fn path_for_nested_route<'a>(prefix: &'a str, path: &'a str) -> Cow<'a, str> { debug_assert!(prefix.starts_with('/')); debug_assert!(path.starts_with('/')); if prefix.ends_with('/') { format!("{prefix}{}", path.trim_start_matches('/')).into() } else if path == "/" { prefix.into() } else { format!("{prefix}{path}").into() } }