1 //! A [`Load`] implementation that measures load using the number of in-flight requests. 2 3 #[cfg(feature = "discover")] 4 use crate::discover::{Change, Discover}; 5 #[cfg(feature = "discover")] 6 use futures_core::{ready, Stream}; 7 #[cfg(feature = "discover")] 8 use pin_project_lite::pin_project; 9 #[cfg(feature = "discover")] 10 use std::pin::Pin; 11 12 use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture}; 13 use super::Load; 14 use std::sync::Arc; 15 use std::task::{Context, Poll}; 16 use tower_service::Service; 17 18 /// Measures the load of the underlying service using the number of currently-pending requests. 19 #[derive(Debug)] 20 pub struct PendingRequests<S, C = CompleteOnResponse> { 21 service: S, 22 ref_count: RefCount, 23 completion: C, 24 } 25 26 /// Shared between instances of [`PendingRequests`] and [`Handle`] to track active references. 27 #[derive(Clone, Debug, Default)] 28 struct RefCount(Arc<()>); 29 30 #[cfg(feature = "discover")] 31 pin_project! { 32 /// Wraps a `D`-typed stream of discovered services with [`PendingRequests`]. 33 #[cfg_attr(docsrs, doc(cfg(feature = "discover")))] 34 #[derive(Debug)] 35 pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> { 36 #[pin] 37 discover: D, 38 completion: C, 39 } 40 } 41 42 /// Represents the number of currently-pending requests to a given service. 43 #[derive(Clone, Copy, Debug, Default, PartialOrd, PartialEq, Ord, Eq)] 44 pub struct Count(usize); 45 46 /// Tracks an in-flight request by reference count. 47 #[derive(Debug)] 48 pub struct Handle(RefCount); 49 50 // ===== impl PendingRequests ===== 51 52 impl<S, C> PendingRequests<S, C> { 53 /// Wraps an `S`-typed service so that its load is tracked by the number of pending requests. new(service: S, completion: C) -> Self54 pub fn new(service: S, completion: C) -> Self { 55 Self { 56 service, 57 completion, 58 ref_count: RefCount::default(), 59 } 60 } 61 handle(&self) -> Handle62 fn handle(&self) -> Handle { 63 Handle(self.ref_count.clone()) 64 } 65 } 66 67 impl<S, C> Load for PendingRequests<S, C> { 68 type Metric = Count; 69 load(&self) -> Count70 fn load(&self) -> Count { 71 // Count the number of references that aren't `self`. 72 Count(self.ref_count.ref_count() - 1) 73 } 74 } 75 76 impl<S, C, Request> Service<Request> for PendingRequests<S, C> 77 where 78 S: Service<Request>, 79 C: TrackCompletion<Handle, S::Response>, 80 { 81 type Response = C::Output; 82 type Error = S::Error; 83 type Future = TrackCompletionFuture<S::Future, C, Handle>; 84 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>85 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 86 self.service.poll_ready(cx) 87 } 88 call(&mut self, req: Request) -> Self::Future89 fn call(&mut self, req: Request) -> Self::Future { 90 TrackCompletionFuture::new( 91 self.completion.clone(), 92 self.handle(), 93 self.service.call(req), 94 ) 95 } 96 } 97 98 // ===== impl PendingRequestsDiscover ===== 99 100 #[cfg(feature = "discover")] 101 impl<D, C> PendingRequestsDiscover<D, C> { 102 /// Wraps a [`Discover`], wrapping all of its services with [`PendingRequests`]. new<Request>(discover: D, completion: C) -> Self where D: Discover, D::Service: Service<Request>, C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,103 pub fn new<Request>(discover: D, completion: C) -> Self 104 where 105 D: Discover, 106 D::Service: Service<Request>, 107 C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>, 108 { 109 Self { 110 discover, 111 completion, 112 } 113 } 114 } 115 116 #[cfg(feature = "discover")] 117 impl<D, C> Stream for PendingRequestsDiscover<D, C> 118 where 119 D: Discover, 120 C: Clone, 121 { 122 type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>; 123 124 /// Yields the next discovery change set. poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>125 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 126 use self::Change::*; 127 128 let this = self.project(); 129 let change = match ready!(this.discover.poll_discover(cx)).transpose()? { 130 None => return Poll::Ready(None), 131 Some(Insert(k, svc)) => Insert(k, PendingRequests::new(svc, this.completion.clone())), 132 Some(Remove(k)) => Remove(k), 133 }; 134 135 Poll::Ready(Some(Ok(change))) 136 } 137 } 138 139 // ==== RefCount ==== 140 141 impl RefCount { ref_count(&self) -> usize142 pub(crate) fn ref_count(&self) -> usize { 143 Arc::strong_count(&self.0) 144 } 145 } 146 147 #[cfg(test)] 148 mod tests { 149 use super::*; 150 use futures_util::future; 151 use std::task::{Context, Poll}; 152 153 struct Svc; 154 impl Service<()> for Svc { 155 type Response = (); 156 type Error = (); 157 type Future = future::Ready<Result<(), ()>>; 158 poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>>159 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> { 160 Poll::Ready(Ok(())) 161 } 162 call(&mut self, (): ()) -> Self::Future163 fn call(&mut self, (): ()) -> Self::Future { 164 future::ok(()) 165 } 166 } 167 168 #[test] default()169 fn default() { 170 let mut svc = PendingRequests::new(Svc, CompleteOnResponse); 171 assert_eq!(svc.load(), Count(0)); 172 173 let rsp0 = svc.call(()); 174 assert_eq!(svc.load(), Count(1)); 175 176 let rsp1 = svc.call(()); 177 assert_eq!(svc.load(), Count(2)); 178 179 let () = tokio_test::block_on(rsp0).unwrap(); 180 assert_eq!(svc.load(), Count(1)); 181 182 let () = tokio_test::block_on(rsp1).unwrap(); 183 assert_eq!(svc.load(), Count(0)); 184 } 185 186 #[test] with_completion()187 fn with_completion() { 188 #[derive(Clone)] 189 struct IntoHandle; 190 impl TrackCompletion<Handle, ()> for IntoHandle { 191 type Output = Handle; 192 fn track_completion(&self, i: Handle, (): ()) -> Handle { 193 i 194 } 195 } 196 197 let mut svc = PendingRequests::new(Svc, IntoHandle); 198 assert_eq!(svc.load(), Count(0)); 199 200 let rsp = svc.call(()); 201 assert_eq!(svc.load(), Count(1)); 202 let i0 = tokio_test::block_on(rsp).unwrap(); 203 assert_eq!(svc.load(), Count(1)); 204 205 let rsp = svc.call(()); 206 assert_eq!(svc.load(), Count(2)); 207 let i1 = tokio_test::block_on(rsp).unwrap(); 208 assert_eq!(svc.load(), Count(2)); 209 210 drop(i1); 211 assert_eq!(svc.load(), Count(1)); 212 213 drop(i0); 214 assert_eq!(svc.load(), Count(0)); 215 } 216 } 217