1 use http::{Request, Uri};
2 use std::{
3 sync::Arc,
4 task::{Context, Poll},
5 };
6 use tower::Layer;
7 use tower_layer::layer_fn;
8 use tower_service::Service;
9
10 #[derive(Clone)]
11 pub(super) struct StripPrefix<S> {
12 inner: S,
13 prefix: Arc<str>,
14 }
15
16 impl<S> StripPrefix<S> {
new(inner: S, prefix: &str) -> Self17 pub(super) fn new(inner: S, prefix: &str) -> Self {
18 Self {
19 inner,
20 prefix: prefix.into(),
21 }
22 }
23
layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone24 pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
25 let prefix = Arc::from(prefix);
26 layer_fn(move |inner| Self {
27 inner,
28 prefix: Arc::clone(&prefix),
29 })
30 }
31 }
32
33 impl<S, B> Service<Request<B>> for StripPrefix<S>
34 where
35 S: Service<Request<B>>,
36 {
37 type Response = S::Response;
38 type Error = S::Error;
39 type Future = S::Future;
40
41 #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>42 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
43 self.inner.poll_ready(cx)
44 }
45
call(&mut self, mut req: Request<B>) -> Self::Future46 fn call(&mut self, mut req: Request<B>) -> Self::Future {
47 if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
48 *req.uri_mut() = new_uri;
49 }
50 self.inner.call(req)
51 }
52 }
53
strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri>54 fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
55 let path_and_query = uri.path_and_query()?;
56
57 // Check whether the prefix matches the path and if so how long the matching prefix is.
58 //
59 // For example:
60 //
61 // prefix = /api
62 // path = /api/users
63 // ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4
64 // characters we get the remainder
65 //
66 // prefix = /api/:version
67 // path = /api/v0/users
68 // ^^^^^^^ this much is matched and the length is 7.
69 let mut matching_prefix_length = Some(0);
70 for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
71 // count the `/`
72 *matching_prefix_length.as_mut().unwrap() += 1;
73
74 match item {
75 Item::Both(path_segment, prefix_segment) => {
76 if prefix_segment.starts_with(':') || path_segment == prefix_segment {
77 // the prefix segment is either a param, which matches anything, or
78 // it actually matches the path segment
79 *matching_prefix_length.as_mut().unwrap() += path_segment.len();
80 } else if prefix_segment.is_empty() {
81 // the prefix ended in a `/` so we got a match.
82 //
83 // For example:
84 //
85 // prefix = /foo/
86 // path = /foo/bar
87 //
88 // The prefix matches and the new path should be `/bar`
89 break;
90 } else {
91 // the prefix segment didn't match so there is no match
92 matching_prefix_length = None;
93 break;
94 }
95 }
96 // the path had more segments than the prefix but we got a match.
97 //
98 // For example:
99 //
100 // prefix = /foo
101 // path = /foo/bar
102 Item::First(_) => {
103 break;
104 }
105 // the prefix had more segments than the path so there is no match
106 Item::Second(_) => {
107 matching_prefix_length = None;
108 break;
109 }
110 }
111 }
112
113 // if the prefix matches it will always do so up until a `/`, it cannot match only
114 // part of a segment. Therefore this will always be at a char boundary and `split_at` wont
115 // panic
116 let after_prefix = uri.path().split_at(matching_prefix_length?).1;
117
118 let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) {
119 (true, None) => after_prefix.parse().unwrap(),
120 (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
121 (false, None) => format!("/{after_prefix}").parse().unwrap(),
122 (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
123 };
124
125 let mut parts = uri.clone().into_parts();
126 parts.path_and_query = Some(new_path_and_query);
127
128 Some(Uri::from_parts(parts).unwrap())
129 }
130
/// Splits an absolute path into its `/`-separated segments.
///
/// The empty piece before the leading `/` is not yielded, so `/a/b` gives `a`, `b`, and `/`
/// gives a single empty segment.
///
/// # Panics
///
/// Panics if `s` does not start with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    if !s.starts_with('/') {
        panic!("path didn't start with '/'. axum should have caught this higher up.");
    }

    // `/` is a single byte, so index 1 is always a char boundary here.
    s[1..].split('/')
}
142
zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>> where I: Iterator, I2: Iterator<Item = I::Item>,143 fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
144 where
145 I: Iterator,
146 I2: Iterator<Item = I::Item>,
147 {
148 let a = a.map(Some).chain(std::iter::repeat_with(|| None));
149 let b = b.map(Some).chain(std::iter::repeat_with(|| None));
150 a.zip(b).map_while(|(a, b)| match (a, b) {
151 (Some(a), Some(b)) => Some(Item::Both(a, b)),
152 (Some(a), None) => Some(Item::First(a)),
153 (None, Some(b)) => Some(Item::Second(b)),
154 (None, None) => None,
155 })
156 }
157
/// One step produced by `zip_longest`, recording which of the two iterators yielded a value.
#[derive(Debug)]
enum Item<T> {
    /// Both iterators yielded a value: (first, second).
    Both(T, T),
    /// Only the first iterator yielded a value; the second is exhausted.
    First(T),
    /// Only the second iterator yielded a value; the first is exhausted.
    Second(T),
}
164
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    /// Generates a test that strips `prefix` from `uri` and compares the resulting URI (as a
    /// string) with `expected` (`None` means "no match").
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let uri = $uri.parse().unwrap();
                let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
                assert_eq!(new_uri.as_deref(), $expected);
            }
        };
    }

    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    // the prefix is empty, so removing it should have no effect
    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),);

    test!(
        param_1,
        uri = "/a",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/:param",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/:param",
        expected = Some("/a"),
    );

    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/:param",
        expected = Some("/"),
    );

    test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,);

    test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,);

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/:param/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/:param/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/:param/c",
        expected = None,
    );

    test!(
        param_10,
        uri = "/a/",
        prefix = "/:param",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/:param/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/:param/",
        expected = Some("/"),
    );

    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    /// `strip_prefix` must never panic for well-formed URIs and `/`-rooted prefixes.
    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix);
        true
    }

    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        uri: Uri,
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        /// Builds a URI and a prefix with the same number of segments. Each prefix segment
        /// is randomly a capture, a copy of the URI segment, or an unrelated segment, and
        /// either side may get a trailing `/`.
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        prefix.push_str(":a");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    /// Returns a random, non-empty ASCII alphanumeric string of 1 to 20 characters.
    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        #[derive(Clone)]
        struct AsciiAlphanumeric(String);

        impl Arbitrary for AsciiAlphanumeric {
            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
                let mut out = String::new();

                let size = u8_between(1, 20, g) as usize;

                while out.len() < size {
                    let c = char::arbitrary(g);
                    if c.is_ascii_alphanumeric() {
                        out.push(c);
                    }
                }
                Self(out)
            }
        }

        let out = AsciiAlphanumeric::arbitrary(g).0;
        assert!(!out.is_empty());
        out
    }

    /// Returns a random value in the inclusive range `lower..=upper`.
    ///
    /// The lower bound used to be exclusive, so `u8_between(1, 20, g)` never produced 1.
    /// That meant single-character segments and single-segment URIs were never generated.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        loop {
            let size = u8::arbitrary(g);
            if (lower..=upper).contains(&size) {
                break size;
            }
        }
    }
}
460