1 #![allow(dead_code)]
2 
3 use futures::future;
4 use std::fmt;
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7 use tokio::sync::mpsc;
8 use tokio_stream::Stream;
9 use tower::Service;
10 
trace_init() -> tracing::subscriber::DefaultGuard11 pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard {
12     let subscriber = tracing_subscriber::fmt()
13         .with_test_writer()
14         .with_max_level(tracing::Level::TRACE)
15         .with_thread_names(true)
16         .finish();
17     tracing::subscriber::set_default(subscriber)
18 }
19 
20 pin_project_lite::pin_project! {
21     #[derive(Clone, Debug)]
22     pub struct IntoStream<S> {
23         #[pin]
24         inner: S
25     }
26 }
27 
28 impl<S> IntoStream<S> {
new(inner: S) -> Self29     pub fn new(inner: S) -> Self {
30         Self { inner }
31     }
32 }
33 
34 impl<I> Stream for IntoStream<mpsc::Receiver<I>> {
35     type Item = I;
36 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>37     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
38         self.project().inner.poll_recv(cx)
39     }
40 }
41 
42 impl<I> Stream for IntoStream<mpsc::UnboundedReceiver<I>> {
43     type Item = I;
44 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>45     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
46         self.project().inner.poll_recv(cx)
47     }
48 }
49 
50 #[derive(Clone, Debug)]
51 pub struct AssertSpanSvc {
52     span: tracing::Span,
53     polled: bool,
54 }
55 
/// Error produced by `AssertSpanSvc`'s span check; wraps a preformatted,
/// human-readable message.
pub struct AssertSpanError(String);

impl fmt::Display for AssertSpanError {
    /// Writes the message as-is, honouring width/alignment/precision flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}

impl fmt::Debug for AssertSpanError {
    // Debug deliberately matches Display: no surrounding quotes and no escaping,
    // so multi-line messages stay readable when a test unwraps the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl std::error::Error for AssertSpanError {}
71 
72 impl AssertSpanSvc {
new(span: tracing::Span) -> Self73     pub fn new(span: tracing::Span) -> Self {
74         Self {
75             span,
76             polled: false,
77         }
78     }
79 
check(&self, func: &str) -> Result<(), AssertSpanError>80     fn check(&self, func: &str) -> Result<(), AssertSpanError> {
81         let current_span = tracing::Span::current();
82         tracing::debug!(?current_span, ?self.span, %func);
83         if current_span == self.span {
84             return Ok(());
85         }
86 
87         Err(AssertSpanError(format!(
88             "{} called outside expected span\n expected: {:?}\n  current: {:?}",
89             func, self.span, current_span
90         )))
91     }
92 }
93 
94 impl Service<()> for AssertSpanSvc {
95     type Response = ();
96     type Error = AssertSpanError;
97     type Future = future::Ready<Result<Self::Response, Self::Error>>;
98 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>99     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100         if self.polled {
101             return Poll::Ready(self.check("poll_ready"));
102         }
103 
104         cx.waker().wake_by_ref();
105         self.polled = true;
106         Poll::Pending
107     }
108 
call(&mut self, _: ()) -> Self::Future109     fn call(&mut self, _: ()) -> Self::Future {
110         future::ready(self.check("call"))
111     }
112 }
113