1 use std::error::Error as StdError;
2 use std::future::Future;
3 use std::marker::Unpin;
4 use std::pin::Pin;
5 use std::task::{Context, Poll};
6 
7 use pin_project_lite::pin_project;
8 use tokio::io::{AsyncRead, AsyncWrite};
9 use tracing::debug;
10 
11 use super::accept::Accept;
12 use super::conn::UpgradeableConnection;
13 use super::server::{Server, Watcher};
14 use crate::body::{Body, HttpBody};
15 use crate::common::drain::{self, Draining, Signal, Watch, Watching};
16 use crate::common::exec::{ConnStreamExec, NewSvcExec};
17 use crate::service::{HttpService, MakeServiceRef};
18 
pin_project! {
    // Future that drives a `Server` until a shutdown `signal` future completes,
    // and then waits for the drain to finish (see the `Future` impl in this file).
    //
    // All behavior lives in `state`. It is structurally pinned because the
    // `Running` variant holds the server and the signal future, both polled
    // through `Pin`.
    //
    // NOTE(review): `Debug` is presumably not implemented because the inner
    // server/signal/drain types are not required to be `Debug`; confirm.
    #[allow(missing_debug_implementations)]
    pub struct Graceful<I, S, F, E> {
        #[pin]
        state: State<I, S, F, E>,
    }
}
26 
pin_project! {
    // Two-phase state machine driven by `Graceful::poll`.
    // Within this file, transitions only go `Running` -> `Draining`, never back.
    #[project = StateProj]
    pub(super) enum State<I, S, F, E> {
        // Serving connections while waiting for the shutdown `signal`.
        Running {
            // Drain channel created in `Graceful::new`. It stays `Some` while
            // running. When the signal fires, `poll` takes it (leaving `None`)
            // and the whole state is replaced right after.
            drain: Option<(Signal, Watch)>,
            #[pin]
            server: Server<I, S, E>,
            #[pin]
            signal: F,
        },
        // Signal received: the server has been dropped and this future
        // resolves when draining finishes.
        Draining { draining: Draining },
    }
}
40 
41 impl<I, S, F, E> Graceful<I, S, F, E> {
new(server: Server<I, S, E>, signal: F) -> Self42     pub(super) fn new(server: Server<I, S, E>, signal: F) -> Self {
43         let drain = Some(drain::channel());
44         Graceful {
45             state: State::Running {
46                 drain,
47                 server,
48                 signal,
49             },
50         }
51     }
52 }
53 
54 impl<I, IO, IE, S, B, F, E> Future for Graceful<I, S, F, E>
55 where
56     I: Accept<Conn = IO, Error = IE>,
57     IE: Into<Box<dyn StdError + Send + Sync>>,
58     IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
59     S: MakeServiceRef<IO, Body, ResBody = B>,
60     S::Error: Into<Box<dyn StdError + Send + Sync>>,
61     B: HttpBody + 'static,
62     B::Error: Into<Box<dyn StdError + Send + Sync>>,
63     F: Future<Output = ()>,
64     E: ConnStreamExec<<S::Service as HttpService<Body>>::Future, B>,
65     E: NewSvcExec<IO, S::Future, S::Service, E, GracefulWatcher>,
66 {
67     type Output = crate::Result<()>;
68 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>69     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
70         let mut me = self.project();
71         loop {
72             let next = {
73                 match me.state.as_mut().project() {
74                     StateProj::Running {
75                         drain,
76                         server,
77                         signal,
78                     } => match signal.poll(cx) {
79                         Poll::Ready(()) => {
80                             debug!("signal received, starting graceful shutdown");
81                             let sig = drain.take().expect("drain channel").0;
82                             State::Draining {
83                                 draining: sig.drain(),
84                             }
85                         }
86                         Poll::Pending => {
87                             let watch = drain.as_ref().expect("drain channel").1.clone();
88                             return server.poll_watch(cx, &GracefulWatcher(watch));
89                         }
90                     },
91                     StateProj::Draining { ref mut draining } => {
92                         return Pin::new(draining).poll(cx).map(Ok);
93                     }
94                 }
95             };
96             me.state.set(next);
97         }
98     }
99 }
100 
101 #[allow(missing_debug_implementations)]
102 #[derive(Clone)]
103 pub struct GracefulWatcher(Watch);
104 
105 impl<I, S, E> Watcher<I, S, E> for GracefulWatcher
106 where
107     I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
108     S: HttpService<Body>,
109     E: ConnStreamExec<S::Future, S::ResBody>,
110     S::ResBody: 'static,
111     <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>,
112 {
113     type Future =
114         Watching<UpgradeableConnection<I, S, E>, fn(Pin<&mut UpgradeableConnection<I, S, E>>)>;
115 
watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future116     fn watch(&self, conn: UpgradeableConnection<I, S, E>) -> Self::Future {
117         self.0.clone().watch(conn, on_drain)
118     }
119 }
120 
on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>) where S: HttpService<Body>, S::Error: Into<Box<dyn StdError + Send + Sync>>, I: AsyncRead + AsyncWrite + Unpin, S::ResBody: HttpBody + 'static, <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>, E: ConnStreamExec<S::Future, S::ResBody>,121 fn on_drain<I, S, E>(conn: Pin<&mut UpgradeableConnection<I, S, E>>)
122 where
123     S: HttpService<Body>,
124     S::Error: Into<Box<dyn StdError + Send + Sync>>,
125     I: AsyncRead + AsyncWrite + Unpin,
126     S::ResBody: HttpBody + 'static,
127     <S::ResBody as HttpBody>::Error: Into<Box<dyn StdError + Send + Sync>>,
128     E: ConnStreamExec<S::Future, S::ResBody>,
129 {
130     conn.graceful_shutdown()
131 }
132