1 #![allow(dead_code)]
2
3 use futures::future;
4 use std::fmt;
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7 use tokio::sync::mpsc;
8 use tokio_stream::Stream;
9 use tower::Service;
10
trace_init() -> tracing::subscriber::DefaultGuard11 pub(crate) fn trace_init() -> tracing::subscriber::DefaultGuard {
12 let subscriber = tracing_subscriber::fmt()
13 .with_test_writer()
14 .with_max_level(tracing::Level::TRACE)
15 .with_thread_names(true)
16 .finish();
17 tracing::subscriber::set_default(subscriber)
18 }
19
pin_project_lite::pin_project! {
    /// Thin wrapper used to drive a value through the `Stream` trait.
    ///
    /// In this file, `Stream` is implemented for `IntoStream` only when the
    /// wrapped value is a tokio `mpsc::Receiver` or `mpsc::UnboundedReceiver`.
    #[derive(Clone, Debug)]
    pub struct IntoStream<S> {
        // Marked `#[pin]` so `poll_next` can reach it via `self.project().inner`.
        #[pin]
        inner: S
    }
}
27
28 impl<S> IntoStream<S> {
new(inner: S) -> Self29 pub fn new(inner: S) -> Self {
30 Self { inner }
31 }
32 }
33
34 impl<I> Stream for IntoStream<mpsc::Receiver<I>> {
35 type Item = I;
36
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>37 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
38 self.project().inner.poll_recv(cx)
39 }
40 }
41
42 impl<I> Stream for IntoStream<mpsc::UnboundedReceiver<I>> {
43 type Item = I;
44
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>45 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
46 self.project().inner.poll_recv(cx)
47 }
48 }
49
/// Test `Service<()>` that fails unless `poll_ready` and `call` run while its
/// expected span is the current span.
#[derive(Clone, Debug)]
pub struct AssertSpanSvc {
    // Compared against `tracing::Span::current()` on every readiness check and call.
    span: tracing::Span,
    // False until the first `poll_ready`, which wakes the task and returns `Pending`.
    // Later polls perform the span check instead.
    polled: bool,
}
55
/// Error produced when `AssertSpanSvc` is driven outside its expected span.
///
/// It carries a preformatted, possibly multi-line message.
pub struct AssertSpanError(String);

impl fmt::Display for AssertSpanError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegating to `str`'s `Display` keeps width/alignment flags working.
        fmt::Display::fmt(self.0.as_str(), f)
    }
}

impl fmt::Debug for AssertSpanError {
    // Deliberately identical to `Display`. `String`'s own `Debug` would quote the
    // message and escape its newlines, which makes `unwrap()` failures hard to read.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl std::error::Error for AssertSpanError {}
71
72 impl AssertSpanSvc {
new(span: tracing::Span) -> Self73 pub fn new(span: tracing::Span) -> Self {
74 Self {
75 span,
76 polled: false,
77 }
78 }
79
check(&self, func: &str) -> Result<(), AssertSpanError>80 fn check(&self, func: &str) -> Result<(), AssertSpanError> {
81 let current_span = tracing::Span::current();
82 tracing::debug!(?current_span, ?self.span, %func);
83 if current_span == self.span {
84 return Ok(());
85 }
86
87 Err(AssertSpanError(format!(
88 "{} called outside expected span\n expected: {:?}\n current: {:?}",
89 func, self.span, current_span
90 )))
91 }
92 }
93
94 impl Service<()> for AssertSpanSvc {
95 type Response = ();
96 type Error = AssertSpanError;
97 type Future = future::Ready<Result<Self::Response, Self::Error>>;
98
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>99 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 if self.polled {
101 return Poll::Ready(self.check("poll_ready"));
102 }
103
104 cx.waker().wake_by_ref();
105 self.polled = true;
106 Poll::Pending
107 }
108
call(&mut self, _: ()) -> Self::Future109 fn call(&mut self, _: ()) -> Self::Future {
110 future::ready(self.check("call"))
111 }
112 }
113