use http::{Request, Uri}; use std::{ sync::Arc, task::{Context, Poll}, }; use tower::Layer; use tower_layer::layer_fn; use tower_service::Service; #[derive(Clone)] pub(super) struct StripPrefix { inner: S, prefix: Arc, } impl StripPrefix { pub(super) fn new(inner: S, prefix: &str) -> Self { Self { inner, prefix: prefix.into(), } } pub(super) fn layer(prefix: &str) -> impl Layer + Clone { let prefix = Arc::from(prefix); layer_fn(move |inner| Self { inner, prefix: Arc::clone(&prefix), }) } } impl Service> for StripPrefix where S: Service>, { type Response = S::Response; type Error = S::Error; type Future = S::Future; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request) -> Self::Future { if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) { *req.uri_mut() = new_uri; } self.inner.call(req) } } fn strip_prefix(uri: &Uri, prefix: &str) -> Option { let path_and_query = uri.path_and_query()?; // Check whether the prefix matches the path and if so how long the matching prefix is. // // For example: // // prefix = /api // path = /api/users // ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4 // characters we get the remainder // // prefix = /api/:version // path = /api/v0/users // ^^^^^^^ this much is matched and the length is 7. let mut matching_prefix_length = Some(0); for item in zip_longest(segments(path_and_query.path()), segments(prefix)) { // count the `/` *matching_prefix_length.as_mut().unwrap() += 1; match item { Item::Both(path_segment, prefix_segment) => { if prefix_segment.starts_with(':') || path_segment == prefix_segment { // the prefix segment is either a param, which matches anything, or // it actually matches the path segment *matching_prefix_length.as_mut().unwrap() += path_segment.len(); } else if prefix_segment.is_empty() { // the prefix ended in a `/` so we got a match. // // For example: // // prefix = /foo/ // path = /foo/bar // // The prefix matches and the new path should be `/bar` break; } else { // the prefix segment didn't match so there is no match matching_prefix_length = None; break; } } // the path had more segments than the prefix but we got a match. // // For example: // // prefix = /foo // path = /foo/bar Item::First(_) => { break; } // the prefix had more segments than the path so there is no match Item::Second(_) => { matching_prefix_length = None; break; } } } // if the prefix matches it will always do so up until a `/`, it cannot match only // part of a segment. Therefore this will always be at a char boundary and `split_at` wont // panic let after_prefix = uri.path().split_at(matching_prefix_length?).1; let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) { (true, None) => after_prefix.parse().unwrap(), (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(), (false, None) => format!("/{after_prefix}").parse().unwrap(), (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(), }; let mut parts = uri.clone().into_parts(); parts.path_and_query = Some(new_path_and_query); Some(Uri::from_parts(parts).unwrap()) } fn segments(s: &str) -> impl Iterator { assert!( s.starts_with('/'), "path didn't start with '/'. axum should have caught this higher up." ); s.split('/') // skip one because paths always start with `/` so `/a/b` would become ["", "a", "b"] // otherwise .skip(1) } fn zip_longest(a: I, b: I2) -> impl Iterator> where I: Iterator, I2: Iterator, { let a = a.map(Some).chain(std::iter::repeat_with(|| None)); let b = b.map(Some).chain(std::iter::repeat_with(|| None)); a.zip(b).map_while(|(a, b)| match (a, b) { (Some(a), Some(b)) => Some(Item::Both(a, b)), (Some(a), None) => Some(Item::First(a)), (None, Some(b)) => Some(Item::Second(b)), (None, None) => None, }) } #[derive(Debug)] enum Item { Both(T, T), First(T), Second(T), } #[cfg(test)] mod tests { #[allow(unused_imports)] use super::*; use quickcheck::Arbitrary; use quickcheck_macros::quickcheck; macro_rules! test { ( $name:ident, uri = $uri:literal, prefix = $prefix:literal, expected = $expected:expr, ) => { #[test] fn $name() { let uri = $uri.parse().unwrap(); let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string()); assert_eq!(new_uri.as_deref(), $expected); } }; } test!(empty, uri = "/", prefix = "/", expected = Some("/"),); test!( single_segment, uri = "/a", prefix = "/a", expected = Some("/"), ); test!( single_segment_root_uri, uri = "/", prefix = "/a", expected = None, ); // the prefix is empty, so removing it should have no effect test!( single_segment_root_prefix, uri = "/a", prefix = "/", expected = Some("/a"), ); test!( single_segment_no_match, uri = "/a", prefix = "/b", expected = None, ); test!( single_segment_trailing_slash, uri = "/a/", prefix = "/a/", expected = Some("/"), ); test!( single_segment_trailing_slash_2, uri = "/a", prefix = "/a/", expected = None, ); test!( single_segment_trailing_slash_3, uri = "/a/", prefix = "/a", expected = Some("/"), ); test!( multi_segment, uri = "/a/b", prefix = "/a", expected = Some("/b"), ); test!( multi_segment_2, uri = "/b/a", prefix = "/a", expected = None, ); test!( multi_segment_3, uri = "/a", prefix = "/a/b", expected = None, ); test!( multi_segment_4, uri = "/a/b", prefix = "/b", expected = None, ); test!( multi_segment_trailing_slash, uri = "/a/b/", prefix = "/a/b/", expected = Some("/"), ); test!( multi_segment_trailing_slash_2, uri = "/a/b", prefix = "/a/b/", expected = None, ); test!( multi_segment_trailing_slash_3, uri = "/a/b/", prefix = "/a/b", expected = Some("/"), ); test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),); test!( param_1, uri = "/a", prefix = "/:param", expected = Some("/"), ); test!( param_2, uri = "/a/b", prefix = "/:param", expected = Some("/b"), ); test!( param_3, uri = "/b/a", prefix = "/:param", expected = Some("/a"), ); test!( param_4, uri = "/a/b", prefix = "/a/:param", expected = Some("/"), ); test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,); test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,); test!( param_7, uri = "/b/a", prefix = "/:param/a", expected = Some("/"), ); test!( param_8, uri = "/a/b/c", prefix = "/a/:param/c", expected = Some("/"), ); test!( param_9, uri = "/c/b/a", prefix = "/a/:param/c", expected = None, ); test!( param_10, uri = "/a/", prefix = "/:param", expected = Some("/"), ); test!(param_11, uri = "/a", prefix = "/:param/", expected = None,); test!( param_12, uri = "/a/", prefix = "/:param/", expected = Some("/"), ); test!( param_13, uri = "/a/a", prefix = "/a/", expected = Some("/a"), ); #[quickcheck] fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool { let UriAndPrefix { uri, prefix } = uri_and_prefix; strip_prefix(&uri, &prefix); true } #[derive(Clone, Debug)] struct UriAndPrefix { uri: Uri, prefix: String, } impl Arbitrary for UriAndPrefix { fn arbitrary(g: &mut quickcheck::Gen) -> Self { let mut uri = String::new(); let mut prefix = String::new(); let size = u8_between(1, 20, g); for _ in 0..size { let segment = ascii_alphanumeric(g); uri.push('/'); uri.push_str(&segment); prefix.push('/'); let make_matching_segment = bool::arbitrary(g); let make_capture = bool::arbitrary(g); match (make_matching_segment, make_capture) { (_, true) => { prefix.push_str(":a"); } (true, false) => { prefix.push_str(&segment); } (false, false) => { prefix.push_str(&ascii_alphanumeric(g)); } } } if bool::arbitrary(g) { uri.push('/'); } if bool::arbitrary(g) { prefix.push('/'); } Self { uri: uri.parse().unwrap(), prefix, } } } fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String { #[derive(Clone)] struct AsciiAlphanumeric(String); impl Arbitrary for AsciiAlphanumeric { fn arbitrary(g: &mut quickcheck::Gen) -> Self { let mut out = String::new(); let size = u8_between(1, 20, g) as usize; while out.len() < size { let c = char::arbitrary(g); if c.is_ascii_alphanumeric() { out.push(c); } } Self(out) } } let out = AsciiAlphanumeric::arbitrary(g).0; assert!(!out.is_empty()); out } fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 { loop { let size = u8::arbitrary(g); if size > lower && size <= upper { break size; } } } }