1 use http::{Request, Uri};
2 use std::{
3     sync::Arc,
4     task::{Context, Poll},
5 };
6 use tower::Layer;
7 use tower_layer::layer_fn;
8 use tower_service::Service;
9 
10 #[derive(Clone)]
11 pub(super) struct StripPrefix<S> {
12     inner: S,
13     prefix: Arc<str>,
14 }
15 
16 impl<S> StripPrefix<S> {
new(inner: S, prefix: &str) -> Self17     pub(super) fn new(inner: S, prefix: &str) -> Self {
18         Self {
19             inner,
20             prefix: prefix.into(),
21         }
22     }
23 
layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone24     pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
25         let prefix = Arc::from(prefix);
26         layer_fn(move |inner| Self {
27             inner,
28             prefix: Arc::clone(&prefix),
29         })
30     }
31 }
32 
33 impl<S, B> Service<Request<B>> for StripPrefix<S>
34 where
35     S: Service<Request<B>>,
36 {
37     type Response = S::Response;
38     type Error = S::Error;
39     type Future = S::Future;
40 
41     #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>42     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
43         self.inner.poll_ready(cx)
44     }
45 
call(&mut self, mut req: Request<B>) -> Self::Future46     fn call(&mut self, mut req: Request<B>) -> Self::Future {
47         if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
48             *req.uri_mut() = new_uri;
49         }
50         self.inner.call(req)
51     }
52 }
53 
strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri>54 fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
55     let path_and_query = uri.path_and_query()?;
56 
57     // Check whether the prefix matches the path and if so how long the matching prefix is.
58     //
59     // For example:
60     //
61     // prefix = /api
62     // path   = /api/users
63     //          ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4
64     //          characters we get the remainder
65     //
66     // prefix = /api/:version
67     // path   = /api/v0/users
68     //          ^^^^^^^ this much is matched and the length is 7.
69     let mut matching_prefix_length = Some(0);
70     for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
71         // count the `/`
72         *matching_prefix_length.as_mut().unwrap() += 1;
73 
74         match item {
75             Item::Both(path_segment, prefix_segment) => {
76                 if prefix_segment.starts_with(':') || path_segment == prefix_segment {
77                     // the prefix segment is either a param, which matches anything, or
78                     // it actually matches the path segment
79                     *matching_prefix_length.as_mut().unwrap() += path_segment.len();
80                 } else if prefix_segment.is_empty() {
81                     // the prefix ended in a `/` so we got a match.
82                     //
83                     // For example:
84                     //
85                     // prefix = /foo/
86                     // path   = /foo/bar
87                     //
88                     // The prefix matches and the new path should be `/bar`
89                     break;
90                 } else {
91                     // the prefix segment didn't match so there is no match
92                     matching_prefix_length = None;
93                     break;
94                 }
95             }
96             // the path had more segments than the prefix but we got a match.
97             //
98             // For example:
99             //
100             // prefix = /foo
101             // path   = /foo/bar
102             Item::First(_) => {
103                 break;
104             }
105             // the prefix had more segments than the path so there is no match
106             Item::Second(_) => {
107                 matching_prefix_length = None;
108                 break;
109             }
110         }
111     }
112 
113     // if the prefix matches it will always do so up until a `/`, it cannot match only
114     // part of a segment. Therefore this will always be at a char boundary and `split_at` wont
115     // panic
116     let after_prefix = uri.path().split_at(matching_prefix_length?).1;
117 
118     let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) {
119         (true, None) => after_prefix.parse().unwrap(),
120         (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
121         (false, None) => format!("/{after_prefix}").parse().unwrap(),
122         (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
123     };
124 
125     let mut parts = uri.clone().into_parts();
126     parts.path_and_query = Some(new_path_and_query);
127 
128     Some(Uri::from_parts(parts).unwrap())
129 }
130 
/// Splits an absolute path into its `/`-separated segments, leaving out the empty piece that
/// would otherwise come before the leading `/`.
///
/// For example `/a/b` yields `a`, `b`; `/` yields a single empty segment; and `/a/` yields
/// `a` followed by an empty segment.
///
/// # Panics
///
/// Panics if `s` does not start with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    assert!(
        s.starts_with('/'),
        "path didn't start with '/'. axum should have caught this higher up."
    );

    // `/` is a single byte, so index 1 is guaranteed to be a char boundary.
    let without_leading_slash = &s[1..];
    without_leading_slash.split('/')
}
142 
/// Pairs up the items of `a` and `b` like [`Iterator::zip`]. Unlike `zip`, it keeps going
/// until *both* iterators are exhausted rather than stopping at the shorter one.
///
/// Each step polls `a` first, then `b`. Once an iterator has returned `None` it is never
/// polled again. Whatever the longer iterator still has left is yielded as [`Item::First`]
/// or [`Item::Second`].
fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
where
    I: Iterator,
    I2: Iterator<Item = I::Item>,
{
    let mut left = a.fuse();
    let mut right = b.fuse();
    std::iter::from_fn(move || {
        let next_left = left.next();
        let next_right = right.next();
        match (next_left, next_right) {
            (Some(x), Some(y)) => Some(Item::Both(x, y)),
            (Some(x), None) => Some(Item::First(x)),
            (None, Some(y)) => Some(Item::Second(y)),
            (None, None) => None,
        }
    })
}

/// One step produced by [`zip_longest`].
#[derive(Debug)]
enum Item<T> {
    /// Both iterators produced an item: the first from `a`, the second from `b`.
    Both(T, T),
    /// Only the first iterator (`a`) still had an item.
    First(T),
    /// Only the second iterator (`b`) still had an item.
    Second(T),
}
164 
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    // Generates a `#[test]` that parses `$uri`, strips `$prefix`, and compares the rendered
    // result (or `None`) against `$expected`.
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let uri = $uri.parse().unwrap();
                let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
                assert_eq!(new_uri.as_deref(), $expected);
            }
        };
    }

    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    // the prefix is empty, so removing it should have no effect
    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);

    test!(
        param_1,
        uri = "/a",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/:param",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/:param",
        expected = Some("/a"),
    );

    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/:param",
        expected = Some("/"),
    );

    test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);

    test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/:param/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/:param/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/:param/c",
        expected = None,
    );

    test!(
        param_10,
        uri = "/a/",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/:param/",
        expected = Some("/"),
    );

    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    // the query string must survive the rewrite
    test!(
        query_preserved,
        uri = "/a/b?x=1",
        prefix = "/a",
        expected = Some("/b?x=1"),
    );

    test!(
        query_preserved_whole_path_stripped,
        uri = "/a?x=1",
        prefix = "/a",
        expected = Some("/?x=1"),
    );

    // scheme and authority of absolute-form URIs are kept
    test!(
        absolute_uri,
        uri = "http://example.com/a/b?c=d",
        prefix = "/a",
        expected = Some("http://example.com/b?c=d"),
    );

    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix);
        true
    }

    /// A randomly generated URI together with a prefix that may or may not match it.
    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        uri: Uri,
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        prefix.push_str(":a");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    /// Generates a non-empty string of ASCII alphanumeric characters (1 to 20 chars long).
    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        #[derive(Clone)]
        struct AsciiAlphanumeric(String);

        impl Arbitrary for AsciiAlphanumeric {
            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
                let mut out = String::new();

                let size = u8_between(1, 20, g) as usize;

                while out.len() < size {
                    let c = char::arbitrary(g);
                    if c.is_ascii_alphanumeric() {
                        out.push(c);
                    }
                }
                Self(out)
            }
        }

        let out = AsciiAlphanumeric::arbitrary(g).0;
        assert!(!out.is_empty());
        out
    }

    /// Draws a `u8` in the inclusive range `lower..=upper` by rejection sampling.
    ///
    /// The lower bound used to be exclusive, so callers asking for `1..=20` never got `1`.
    /// As a result single-segment URIs and single-character segments were never generated.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        loop {
            let size = u8::arbitrary(g);
            if (lower..=upper).contains(&size) {
                break size;
            }
        }
    }
}
460