1 //! A [`Load`] implementation that measures load using the number of in-flight requests.
2 
3 #[cfg(feature = "discover")]
4 use crate::discover::{Change, Discover};
5 #[cfg(feature = "discover")]
6 use futures_core::{ready, Stream};
7 #[cfg(feature = "discover")]
8 use pin_project_lite::pin_project;
9 #[cfg(feature = "discover")]
10 use std::pin::Pin;
11 
12 use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13 use super::Load;
14 use std::sync::Arc;
15 use std::task::{Context, Poll};
16 use tower_service::Service;
17 
/// Measures the load of the underlying service using the number of currently-pending requests.
///
/// Every call to the wrapped service takes a [`Handle`] that shares this
/// value's reference count; the reported [`Load`] is the number of those
/// handles that are still alive.
#[derive(Debug)]
pub struct PendingRequests<S, C = CompleteOnResponse> {
    /// The wrapped service that requests are forwarded to.
    service: S,
    /// Shared counter; this value holds one reference and each in-flight
    /// request holds another through its [`Handle`].
    ref_count: RefCount,
    /// Completion strategy, cloned into each response future together with
    /// that request's [`Handle`].
    completion: C,
}
25 
/// Shared between instances of [`PendingRequests`] and [`Handle`] to track active references.
///
/// Only the strong count of the inner `Arc` is meaningful; the `()` payload
/// carries no data. Cloning a `RefCount` increments the count, and dropping a
/// clone decrements it.
#[derive(Clone, Debug, Default)]
struct RefCount(Arc<()>);
29 
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services with [`PendingRequests`].
    ///
    /// Each service inserted by the inner discovery stream is wrapped with a
    /// clone of the completion strategy `C`.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> {
        // The inner discovery stream; pinned so it can be polled through
        // `Pin<&mut Self>` in `poll_next`.
        #[pin]
        discover: D,
        // Completion strategy cloned into every inserted `PendingRequests`.
        completion: C,
    }
}
41 
/// Represents the number of currently-pending requests to a given service.
///
/// This is the [`Load`] metric reported by [`PendingRequests`]. Its derived
/// ordering compares the wrapped counts numerically.
#[derive(Clone, Copy, Debug, Default, PartialOrd, PartialEq, Ord, Eq)]
pub struct Count(usize);
45 
/// Tracks an in-flight request by reference count.
///
/// Holds a clone of a [`PendingRequests`] service's shared reference count, so
/// the request is included in that service's load until this handle is dropped.
#[derive(Debug)]
pub struct Handle(RefCount);
49 
50 // ===== impl PendingRequests =====
51 
52 impl<S, C> PendingRequests<S, C> {
53     /// Wraps an `S`-typed service so that its load is tracked by the number of pending requests.
new(service: S, completion: C) -> Self54     pub fn new(service: S, completion: C) -> Self {
55         Self {
56             service,
57             completion,
58             ref_count: RefCount::default(),
59         }
60     }
61 
handle(&self) -> Handle62     fn handle(&self) -> Handle {
63         Handle(self.ref_count.clone())
64     }
65 }
66 
67 impl<S, C> Load for PendingRequests<S, C> {
68     type Metric = Count;
69 
load(&self) -> Count70     fn load(&self) -> Count {
71         // Count the number of references that aren't `self`.
72         Count(self.ref_count.ref_count() - 1)
73     }
74 }
75 
76 impl<S, C, Request> Service<Request> for PendingRequests<S, C>
77 where
78     S: Service<Request>,
79     C: TrackCompletion<Handle, S::Response>,
80 {
81     type Response = C::Output;
82     type Error = S::Error;
83     type Future = TrackCompletionFuture<S::Future, C, Handle>;
84 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>85     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86         self.service.poll_ready(cx)
87     }
88 
call(&mut self, req: Request) -> Self::Future89     fn call(&mut self, req: Request) -> Self::Future {
90         TrackCompletionFuture::new(
91             self.completion.clone(),
92             self.handle(),
93             self.service.call(req),
94         )
95     }
96 }
97 
98 // ===== impl PendingRequestsDiscover =====
99 
#[cfg(feature = "discover")]
impl<D, C> PendingRequestsDiscover<D, C> {
    /// Wraps a [`Discover`], wrapping all of its services with [`PendingRequests`].
    ///
    /// The body does not use the `Request` parameter or the `where` bounds.
    /// They only ensure, at compile time, that `completion` can track the
    /// responses produced by the discovered services.
    pub fn new<Request>(discover: D, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        PendingRequestsDiscover {
            completion,
            discover,
        }
    }
}
115 
#[cfg(feature = "discover")]
impl<D, C> Stream for PendingRequestsDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>;

    /// Polls the inner discovery stream and forwards its next change.
    ///
    /// Each inserted service is wrapped in [`PendingRequests`] with a clone of
    /// the completion strategy. Removals and errors pass through unchanged, and
    /// this stream ends when the inner stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let next = match ready!(this.discover.poll_discover(cx)) {
            None => None,
            Some(Err(error)) => Some(Err(error)),
            Some(Ok(Change::Insert(key, service))) => {
                let tracked = PendingRequests::new(service, this.completion.clone());
                Some(Ok(Change::Insert(key, tracked)))
            }
            Some(Ok(Change::Remove(key))) => Some(Ok(Change::Remove(key))),
        };
        Poll::Ready(next)
    }
}
138 
139 // ==== RefCount ====
140 
141 impl RefCount {
ref_count(&self) -> usize142     pub(crate) fn ref_count(&self) -> usize {
143         Arc::strong_count(&self.0)
144     }
145 }
146 
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    /// A trivial service that is always ready and answers immediately with `()`.
    struct Svc;

    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With `CompleteOnResponse`, a request stays pending until its response future resolves.
    #[test]
    fn default() {
        let mut service = PendingRequests::new(Svc, CompleteOnResponse);
        assert_eq!(service.load(), Count(0));

        let first = service.call(());
        assert_eq!(service.load(), Count(1));

        let second = service.call(());
        assert_eq!(service.load(), Count(2));

        let () = tokio_test::block_on(first).unwrap();
        assert_eq!(service.load(), Count(1));

        let () = tokio_test::block_on(second).unwrap();
        assert_eq!(service.load(), Count(0));
    }

    /// A completion strategy that hands the `Handle` back as the response keeps
    /// the request pending until the caller drops that handle.
    #[test]
    fn with_completion() {
        #[derive(Clone)]
        struct IntoHandle;

        impl TrackCompletion<Handle, ()> for IntoHandle {
            type Output = Handle;
            fn track_completion(&self, handle: Handle, (): ()) -> Handle {
                handle
            }
        }

        let mut service = PendingRequests::new(Svc, IntoHandle);
        assert_eq!(service.load(), Count(0));

        let pending = service.call(());
        assert_eq!(service.load(), Count(1));
        let first = tokio_test::block_on(pending).unwrap();
        assert_eq!(service.load(), Count(1));

        let pending = service.call(());
        assert_eq!(service.load(), Count(2));
        let second = tokio_test::block_on(pending).unwrap();
        assert_eq!(service.load(), Count(2));

        drop(second);
        assert_eq!(service.load(), Count(1));

        drop(first);
        assert_eq!(service.load(), Count(0));
    }
}
217