//! This module provides functionality to aid in routing requests between [`Service`]s.
2 //!
3 //! # Example
4 //!
//! [`Steer`] can, for example, be used to create a router, akin to what you might find in web
//! frameworks.
7 //!
8 //! Here, `GET /` will be sent to the `root` service, while all other requests go to `not_found`.
9 //!
10 //! ```rust
11 //! # use std::task::{Context, Poll};
12 //! # use tower_service::Service;
13 //! # use futures_util::future::{ready, Ready, poll_fn};
14 //! # use tower::steer::Steer;
15 //! # use tower::service_fn;
16 //! # use tower::util::BoxService;
17 //! # use tower::ServiceExt;
18 //! # use std::convert::Infallible;
19 //! use http::{Request, Response, StatusCode, Method};
20 //!
21 //! # #[tokio::main]
22 //! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
23 //! // Service that responds to `GET /`
24 //! let root = service_fn(|req: Request<String>| async move {
25 //!     # assert_eq!(req.uri().path(), "/");
26 //!     let res = Response::new("Hello, World!".to_string());
27 //!     Ok::<_, Infallible>(res)
28 //! });
29 //! // We have to box the service so its type gets erased and we can put it in a `Vec` with other
30 //! // services
31 //! let root = BoxService::new(root);
32 //!
33 //! // Service that responds with `404 Not Found` to all requests
34 //! let not_found = service_fn(|req: Request<String>| async move {
35 //!     let res = Response::builder()
36 //!         .status(StatusCode::NOT_FOUND)
37 //!         .body(String::new())
38 //!         .expect("response is valid");
39 //!     Ok::<_, Infallible>(res)
40 //! });
41 //! // Box that as well
42 //! let not_found = BoxService::new(not_found);
43 //!
44 //! let mut svc = Steer::new(
45 //!     // All services we route between
46 //!     vec![root, not_found],
47 //!     // How we pick which service to send the request to
48 //!     |req: &Request<String>, _services: &[_]| {
49 //!         if req.method() == Method::GET && req.uri().path() == "/" {
50 //!             0 // Index of `root`
51 //!         } else {
52 //!             1 // Index of `not_found`
53 //!         }
54 //!     },
55 //! );
56 //!
57 //! // This request will get sent to `root`
58 //! let req = Request::get("/").body(String::new()).unwrap();
59 //! let res = svc.ready().await?.call(req).await?;
60 //! assert_eq!(res.into_body(), "Hello, World!");
61 //!
62 //! // This request will get sent to `not_found`
63 //! let req = Request::get("/does/not/exist").body(String::new()).unwrap();
64 //! let res = svc.ready().await?.call(req).await?;
65 //! assert_eq!(res.status(), StatusCode::NOT_FOUND);
66 //! assert_eq!(res.into_body(), "");
67 //! #
68 //! # Ok(())
69 //! # }
70 //! ```
71 use std::task::{Context, Poll};
72 use std::{collections::VecDeque, fmt, marker::PhantomData};
73 use tower_service::Service;
74 
/// Strategy used by [`Steer`] to decide which inner `Service` should handle a given `Req`.
pub trait Picker<S, Req> {
    /// Chooses the service that `r` should be routed to.
    ///
    /// `services` is the full list handed to [`Steer::new`], in its original order; the value
    /// returned is the position of the chosen service within that list. Returning an index that
    /// is out of range makes [`Steer`]'s `call` panic.
    fn pick(&mut self, r: &Req, services: &[S]) -> usize;
}
80 
81 impl<S, F, Req> Picker<S, Req> for F
82 where
83     F: Fn(&Req, &[S]) -> usize,
84 {
pick(&mut self, r: &Req, services: &[S]) -> usize85     fn pick(&mut self, r: &Req, services: &[S]) -> usize {
86         self(r, services)
87     }
88 }
89 
/// [`Steer`] manages a list of [`Service`]s which all handle the same type of request.
///
/// An example use case is a sharded service.
/// It accepts new requests, then:
/// 1. Determines, via the provided [`Picker`], which [`Service`] the request corresponds to.
/// 2. Waits (in [`Service::poll_ready`]) for *all* services to be ready.
/// 3. Calls the correct [`Service`] with the request, and returns a future corresponding to the
///    call.
///
/// Note that [`Steer`] must wait for all services to be ready since it can't know ahead of time
/// which [`Service`] the next message will arrive for, and is unwilling to buffer items
/// indefinitely. This will cause head-of-line blocking unless paired with a [`Service`] that does
/// buffer items indefinitely, and thus always returns [`Poll::Ready`]. For example, wrapping each
/// component service with a [`Buffer`] with a high enough limit (the maximum number of concurrent
/// requests) will prevent head-of-line blocking in [`Steer`].
///
/// [`Buffer`]: crate::buffer::Buffer
pub struct Steer<S, F, Req> {
    /// Decides which entry of `services` receives each request.
    router: F,
    /// The inner services; `Picker::pick` returns indices into this list.
    services: Vec<S>,
    /// Indices of services still waiting to report readiness. `poll_ready` drains this queue
    /// from the front, and `call` pushes the index of the service it just called onto the back.
    not_ready: VecDeque<usize>,
    /// Ties the request type parameter to `Steer` without storing a value of it.
    _phantom: PhantomData<Req>,
}
113 
114 impl<S, F, Req> Steer<S, F, Req> {
115     /// Make a new [`Steer`] with a list of [`Service`]'s and a [`Picker`].
116     ///
117     /// Note: the order of the [`Service`]'s is significant for [`Picker::pick`]'s return value.
new(services: impl IntoIterator<Item = S>, router: F) -> Self118     pub fn new(services: impl IntoIterator<Item = S>, router: F) -> Self {
119         let services: Vec<_> = services.into_iter().collect();
120         let not_ready: VecDeque<_> = services.iter().enumerate().map(|(i, _)| i).collect();
121         Self {
122             router,
123             services,
124             not_ready,
125             _phantom: PhantomData,
126         }
127     }
128 }
129 
130 impl<S, Req, F> Service<Req> for Steer<S, F, Req>
131 where
132     S: Service<Req>,
133     F: Picker<S, Req>,
134 {
135     type Response = S::Response;
136     type Error = S::Error;
137     type Future = S::Future;
138 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>139     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140         loop {
141             // must wait for *all* services to be ready.
142             // this will cause head-of-line blocking unless the underlying services are always ready.
143             if self.not_ready.is_empty() {
144                 return Poll::Ready(Ok(()));
145             } else {
146                 if self.services[self.not_ready[0]]
147                     .poll_ready(cx)?
148                     .is_pending()
149                 {
150                     return Poll::Pending;
151                 }
152 
153                 self.not_ready.pop_front();
154             }
155         }
156     }
157 
call(&mut self, req: Req) -> Self::Future158     fn call(&mut self, req: Req) -> Self::Future {
159         assert!(
160             self.not_ready.is_empty(),
161             "Steer must wait for all services to be ready. Did you forget to call poll_ready()?"
162         );
163 
164         let idx = self.router.pick(&req, &self.services[..]);
165         let cl = &mut self.services[idx];
166         self.not_ready.push_back(idx);
167         cl.call(req)
168     }
169 }
170 
171 impl<S, F, Req> Clone for Steer<S, F, Req>
172 where
173     S: Clone,
174     F: Clone,
175 {
clone(&self) -> Self176     fn clone(&self) -> Self {
177         Self {
178             router: self.router.clone(),
179             services: self.services.clone(),
180             not_ready: self.not_ready.clone(),
181             _phantom: PhantomData,
182         }
183     }
184 }
185 
186 impl<S, F, Req> fmt::Debug for Steer<S, F, Req>
187 where
188     S: fmt::Debug,
189     F: fmt::Debug,
190 {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result191     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192         let Self {
193             router,
194             services,
195             not_ready,
196             _phantom,
197         } = self;
198         f.debug_struct("Steer")
199             .field("router", router)
200             .field("services", services)
201             .field("not_ready", not_ready)
202             .finish()
203     }
204 }
205