1 //! Server-Sent Events (SSE) responses.
2 //!
3 //! # Example
4 //!
5 //! ```
6 //! use axum::{
7 //!     Router,
8 //!     routing::get,
9 //!     response::sse::{Event, KeepAlive, Sse},
10 //! };
11 //! use std::{time::Duration, convert::Infallible};
//! use tokio_stream::StreamExt as _;
13 //! use futures_util::stream::{self, Stream};
14 //!
15 //! let app = Router::new().route("/sse", get(sse_handler));
16 //!
17 //! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
18 //!     // A `Stream` that repeats an event every second
19 //!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
20 //!         .map(Ok)
21 //!         .throttle(Duration::from_secs(1));
22 //!
23 //!     Sse::new(stream).keep_alive(KeepAlive::default())
24 //! }
25 //! # async {
26 //! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
27 //! # };
28 //! ```
29 
30 use crate::{
31     body::{Bytes, HttpBody},
32     BoxError,
33 };
34 use axum_core::{
35     body,
36     response::{IntoResponse, Response},
37 };
38 use bytes::{BufMut, BytesMut};
39 use futures_util::{
40     ready,
41     stream::{Stream, TryStream},
42 };
43 use pin_project_lite::pin_project;
44 use std::{
45     fmt,
46     future::Future,
47     pin::Pin,
48     task::{Context, Poll},
49     time::Duration,
50 };
51 use sync_wrapper::SyncWrapper;
52 use tokio::time::Sleep;
53 
54 /// An SSE response
55 #[derive(Clone)]
56 #[must_use]
57 pub struct Sse<S> {
58     stream: S,
59     keep_alive: Option<KeepAlive>,
60 }
61 
62 impl<S> Sse<S> {
63     /// Create a new [`Sse`] response that will respond with the given stream of
64     /// [`Event`]s.
65     ///
66     /// See the [module docs](self) for more details.
new(stream: S) -> Self where S: TryStream<Ok = Event> + Send + 'static, S::Error: Into<BoxError>,67     pub fn new(stream: S) -> Self
68     where
69         S: TryStream<Ok = Event> + Send + 'static,
70         S::Error: Into<BoxError>,
71     {
72         Sse {
73             stream,
74             keep_alive: None,
75         }
76     }
77 
78     /// Configure the interval between keep-alive messages.
79     ///
80     /// Defaults to no keep-alive messages.
keep_alive(mut self, keep_alive: KeepAlive) -> Self81     pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
82         self.keep_alive = Some(keep_alive);
83         self
84     }
85 }
86 
87 impl<S> fmt::Debug for Sse<S> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89         f.debug_struct("Sse")
90             .field("stream", &format_args!("{}", std::any::type_name::<S>()))
91             .field("keep_alive", &self.keep_alive)
92             .finish()
93     }
94 }
95 
96 impl<S, E> IntoResponse for Sse<S>
97 where
98     S: Stream<Item = Result<Event, E>> + Send + 'static,
99     E: Into<BoxError>,
100 {
into_response(self) -> Response101     fn into_response(self) -> Response {
102         (
103             [
104                 (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
105                 (http::header::CACHE_CONTROL, "no-cache"),
106             ],
107             body::boxed(Body {
108                 event_stream: SyncWrapper::new(self.stream),
109                 keep_alive: self.keep_alive.map(KeepAliveStream::new),
110             }),
111         )
112             .into_response()
113     }
114 }
115 
pin_project! {
    // Response body produced by `Sse::into_response`.
    //
    // Yields each event from `event_stream` as encoded bytes and, when
    // `keep_alive` is configured, emits keep-alive messages while the event
    // stream is pending.
    struct Body<S> {
        #[pin]
        event_stream: SyncWrapper<S>,
        // `None` when keep-alive messages are disabled.
        #[pin]
        keep_alive: Option<KeepAliveStream>,
    }
}
124 
125 impl<S, E> HttpBody for Body<S>
126 where
127     S: Stream<Item = Result<Event, E>>,
128 {
129     type Data = Bytes;
130     type Error = E;
131 
poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>132     fn poll_data(
133         self: Pin<&mut Self>,
134         cx: &mut Context<'_>,
135     ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
136         let this = self.project();
137 
138         match this.event_stream.get_pin_mut().poll_next(cx) {
139             Poll::Pending => {
140                 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
141                     keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
142                 } else {
143                     Poll::Pending
144                 }
145             }
146             Poll::Ready(Some(Ok(event))) => {
147                 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
148                     keep_alive.reset();
149                 }
150                 Poll::Ready(Some(Ok(event.finalize())))
151             }
152             Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
153             Poll::Ready(None) => Poll::Ready(None),
154         }
155     }
156 
poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>>157     fn poll_trailers(
158         self: Pin<&mut Self>,
159         _cx: &mut Context<'_>,
160     ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
161         Poll::Ready(Ok(None))
162     }
163 }
164 
165 /// Server-sent event
166 #[derive(Debug, Default, Clone)]
167 #[must_use]
168 pub struct Event {
169     buffer: BytesMut,
170     flags: EventFlags,
171 }
172 
173 impl Event {
174     /// Set the event's data data field(s) (`data:<content>`)
175     ///
176     /// Newlines in `data` will automatically be broken across `data:` fields.
177     ///
178     /// This corresponds to [`MessageEvent`'s data field].
179     ///
180     /// Note that events with an empty data field will be ignored by the browser.
181     ///
182     /// # Panics
183     ///
184     /// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
185     /// - Panics if `data` or `json_data` have already been called.
186     ///
187     /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
data<T>(mut self, data: T) -> Event where T: AsRef<str>,188     pub fn data<T>(mut self, data: T) -> Event
189     where
190         T: AsRef<str>,
191     {
192         if self.flags.contains(EventFlags::HAS_DATA) {
193             panic!("Called `EventBuilder::data` multiple times");
194         }
195 
196         for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
197             self.field("data", line);
198         }
199 
200         self.flags.insert(EventFlags::HAS_DATA);
201 
202         self
203     }
204 
205     /// Set the event's data field to a value serialized as unformatted JSON (`data:<content>`).
206     ///
207     /// This corresponds to [`MessageEvent`'s data field].
208     ///
209     /// # Panics
210     ///
211     /// Panics if `data` or `json_data` have already been called.
212     ///
213     /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
214     #[cfg(feature = "json")]
json_data<T>(mut self, data: T) -> serde_json::Result<Event> where T: serde::Serialize,215     pub fn json_data<T>(mut self, data: T) -> serde_json::Result<Event>
216     where
217         T: serde::Serialize,
218     {
219         if self.flags.contains(EventFlags::HAS_DATA) {
220             panic!("Called `EventBuilder::json_data` multiple times");
221         }
222 
223         self.buffer.extend_from_slice(b"data:");
224         serde_json::to_writer((&mut self.buffer).writer(), &data)?;
225         self.buffer.put_u8(b'\n');
226 
227         self.flags.insert(EventFlags::HAS_DATA);
228 
229         Ok(self)
230     }
231 
232     /// Set the event's comment field (`:<comment-text>`).
233     ///
234     /// This field will be ignored by most SSE clients.
235     ///
236     /// Unlike other functions, this function can be called multiple times to add many comments.
237     ///
238     /// # Panics
239     ///
240     /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
241     /// comments.
comment<T>(mut self, comment: T) -> Event where T: AsRef<str>,242     pub fn comment<T>(mut self, comment: T) -> Event
243     where
244         T: AsRef<str>,
245     {
246         self.field("", comment.as_ref());
247         self
248     }
249 
250     /// Set the event's name field (`event:<event-name>`).
251     ///
252     /// This corresponds to the `type` parameter given when calling `addEventListener` on an
253     /// [`EventSource`]. For example, `.event("update")` should correspond to
254     /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
255     /// [`message` event] instead.
256     ///
257     /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
258     /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
259     ///
260     /// # Panics
261     ///
262     /// - Panics if `event` contains any newlines or carriage returns.
263     /// - Panics if this function has already been called on this event.
event<T>(mut self, event: T) -> Event where T: AsRef<str>,264     pub fn event<T>(mut self, event: T) -> Event
265     where
266         T: AsRef<str>,
267     {
268         if self.flags.contains(EventFlags::HAS_EVENT) {
269             panic!("Called `EventBuilder::event` multiple times");
270         }
271         self.flags.insert(EventFlags::HAS_EVENT);
272 
273         self.field("event", event.as_ref());
274 
275         self
276     }
277 
278     /// Set the event's retry timeout field (`retry:<timeout>`).
279     ///
280     /// This sets how long clients will wait before reconnecting if they are disconnected from the
281     /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
282     /// wish, such as if they implement exponential backoff.
283     ///
284     /// # Panics
285     ///
286     /// Panics if this function has already been called on this event.
retry(mut self, duration: Duration) -> Event287     pub fn retry(mut self, duration: Duration) -> Event {
288         if self.flags.contains(EventFlags::HAS_RETRY) {
289             panic!("Called `EventBuilder::retry` multiple times");
290         }
291         self.flags.insert(EventFlags::HAS_RETRY);
292 
293         self.buffer.extend_from_slice(b"retry:");
294 
295         let secs = duration.as_secs();
296         let millis = duration.subsec_millis();
297 
298         if secs > 0 {
299             // format seconds
300             self.buffer
301                 .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
302 
303             // pad milliseconds
304             if millis < 10 {
305                 self.buffer.extend_from_slice(b"00");
306             } else if millis < 100 {
307                 self.buffer.extend_from_slice(b"0");
308             }
309         }
310 
311         // format milliseconds
312         self.buffer
313             .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
314 
315         self.buffer.put_u8(b'\n');
316 
317         self
318     }
319 
320     /// Set the event's identifier field (`id:<identifier>`).
321     ///
322     /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
323     /// the browser will set that field to the last known message ID, starting with the empty
324     /// string.
325     ///
326     /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
327     ///
328     /// # Panics
329     ///
330     /// - Panics if `id` contains any newlines, carriage returns or null characters.
331     /// - Panics if this function has already been called on this event.
id<T>(mut self, id: T) -> Event where T: AsRef<str>,332     pub fn id<T>(mut self, id: T) -> Event
333     where
334         T: AsRef<str>,
335     {
336         if self.flags.contains(EventFlags::HAS_ID) {
337             panic!("Called `EventBuilder::id` multiple times");
338         }
339         self.flags.insert(EventFlags::HAS_ID);
340 
341         let id = id.as_ref().as_bytes();
342         assert_eq!(
343             memchr::memchr(b'\0', id),
344             None,
345             "Event ID cannot contain null characters",
346         );
347 
348         self.field("id", id);
349         self
350     }
351 
field(&mut self, name: &str, value: impl AsRef<[u8]>)352     fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
353         let value = value.as_ref();
354         assert_eq!(
355             memchr::memchr2(b'\r', b'\n', value),
356             None,
357             "SSE field value cannot contain newlines or carriage returns",
358         );
359         self.buffer.extend_from_slice(name.as_bytes());
360         self.buffer.put_u8(b':');
361         // Prevent values that start with spaces having that space stripped
362         if value.starts_with(b" ") {
363             self.buffer.put_u8(b' ');
364         }
365         self.buffer.extend_from_slice(value);
366         self.buffer.put_u8(b'\n');
367     }
368 
finalize(mut self) -> Bytes369     fn finalize(mut self) -> Bytes {
370         self.buffer.put_u8(b'\n');
371         self.buffer.freeze()
372     }
373 }
374 
bitflags::bitflags! {
    // Records which single-use fields of an `Event` have already been set, so
    // that setting one a second time can panic.
    #[derive(Default)]
    struct EventFlags: u8 {
        // Set by `Event::data` and `Event::json_data`.
        const HAS_DATA  = 0b0001;
        // Set by `Event::event`.
        const HAS_EVENT = 0b0010;
        // Set by `Event::retry`.
        const HAS_RETRY = 0b0100;
        // Set by `Event::id`.
        const HAS_ID    = 0b1000;
    }
}
384 
385 /// Configure the interval between keep-alive messages, the content
386 /// of each message, and the associated stream.
387 #[derive(Debug, Clone)]
388 #[must_use]
389 pub struct KeepAlive {
390     event: Bytes,
391     max_interval: Duration,
392 }
393 
394 impl KeepAlive {
395     /// Create a new `KeepAlive`.
new() -> Self396     pub fn new() -> Self {
397         Self {
398             event: Bytes::from_static(b":\n\n"),
399             max_interval: Duration::from_secs(15),
400         }
401     }
402 
403     /// Customize the interval between keep-alive messages.
404     ///
405     /// Default is 15 seconds.
interval(mut self, time: Duration) -> Self406     pub fn interval(mut self, time: Duration) -> Self {
407         self.max_interval = time;
408         self
409     }
410 
411     /// Customize the text of the keep-alive message.
412     ///
413     /// Default is an empty comment.
414     ///
415     /// # Panics
416     ///
417     /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
418     /// comments.
text<I>(self, text: I) -> Self where I: AsRef<str>,419     pub fn text<I>(self, text: I) -> Self
420     where
421         I: AsRef<str>,
422     {
423         self.event(Event::default().comment(text))
424     }
425 
426     /// Customize the event of the keep-alive message.
427     ///
428     /// Default is an empty comment.
429     ///
430     /// # Panics
431     ///
432     /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
433     /// comments.
event(mut self, event: Event) -> Self434     pub fn event(mut self, event: Event) -> Self {
435         self.event = event.finalize();
436         self
437     }
438 }
439 
440 impl Default for KeepAlive {
default() -> Self441     fn default() -> Self {
442         Self::new()
443     }
444 }
445 
pin_project! {
    // Timer state that drives keep-alive messages for one `Sse` response body.
    #[derive(Debug)]
    struct KeepAliveStream {
        // The configured interval and the pre-encoded keep-alive bytes.
        keep_alive: KeepAlive,
        // Completes once `keep_alive.max_interval` elapses without a reset.
        #[pin]
        alive_timer: Sleep,
    }
}
454 
455 impl KeepAliveStream {
new(keep_alive: KeepAlive) -> Self456     fn new(keep_alive: KeepAlive) -> Self {
457         Self {
458             alive_timer: tokio::time::sleep(keep_alive.max_interval),
459             keep_alive,
460         }
461     }
462 
reset(self: Pin<&mut Self>)463     fn reset(self: Pin<&mut Self>) {
464         let this = self.project();
465         this.alive_timer
466             .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
467     }
468 
poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes>469     fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
470         let this = self.as_mut().project();
471 
472         ready!(this.alive_timer.poll(cx));
473 
474         let event = this.keep_alive.event.clone();
475 
476         self.reset();
477 
478         Poll::Ready(event)
479     }
480 }
481 
memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_>482 fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
483     MemchrSplit {
484         needle,
485         haystack: Some(haystack),
486     }
487 }
488 
/// Iterator returned by `memchr_split`.
struct MemchrSplit<'a> {
    /// Byte to split on.
    needle: u8,
    /// Input not yet yielded; `None` once the final piece has been returned.
    haystack: Option<&'a [u8]>,
}
493 
494 impl<'a> Iterator for MemchrSplit<'a> {
495     type Item = &'a [u8];
next(&mut self) -> Option<Self::Item>496     fn next(&mut self) -> Option<Self::Item> {
497         let haystack = self.haystack?;
498         if let Some(pos) = memchr::memchr(self.needle, haystack) {
499             let (front, back) = haystack.split_at(pos);
500             self.haystack = Some(&back[1..]);
501             Some(front)
502         } else {
503             self.haystack.take()
504         }
505     }
506 }
507 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data:\tfoobar\n\n");

        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    #[test]
    fn multiline_data_is_split_into_fields() {
        let event = Event::default().data("one\ntwo");
        assert_eq!(&*event.finalize(), b"data:one\ndata:two\n\n");

        let trailing_newline = Event::default().data("one\n");
        assert_eq!(&*trailing_newline.finalize(), b"data:one\ndata:\n\n");

        let empty = Event::default().data("");
        assert_eq!(&*empty.finalize(), b"data:\n\n");
    }

    #[test]
    #[should_panic]
    fn data_with_carriage_return_panics() {
        let _ = Event::default().data("one\rtwo");
    }

    #[test]
    #[should_panic(expected = "Event ID cannot contain null characters")]
    fn id_with_null_character_panics() {
        let _ = Event::default().id("a\0b");
    }

    #[test]
    #[should_panic(expected = "multiple times")]
    fn setting_data_twice_panics() {
        let _ = Event::default().data("one").data("two");
    }

    #[test]
    fn retry_is_formatted_in_milliseconds() {
        let encode = |duration: Duration| Event::default().retry(duration).finalize();

        assert_eq!(&*encode(Duration::from_secs(30)), b"retry:30000\n\n");
        assert_eq!(&*encode(Duration::from_millis(1005)), b"retry:1005\n\n");
        assert_eq!(&*encode(Duration::from_millis(1050)), b"retry:1050\n\n");
        assert_eq!(&*encode(Duration::from_millis(1500)), b"retry:1500\n\n");
        assert_eq!(&*encode(Duration::from_millis(5)), b"retry:5\n\n");
        // sub-millisecond precision is truncated
        assert_eq!(&*encode(Duration::from_micros(1500)), b"retry:1\n\n");
        assert_eq!(&*encode(Duration::from_secs(0)), b"retry:0\n\n");
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").send().await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(event_fields.get("comment").is_none());

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(event_fields.get("comment").is_none());

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").send().await;

        for _ in 0..5 {
            // first message should be an event
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").send().await;

        // first message should be an event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    /// Parse one encoded SSE event into a map from field name to value.
    ///
    /// Comment lines (`:text`) are stored under the key `"comment"`. As the SSE
    /// spec prescribes, exactly one space after the colon is stripped from the
    /// value, so leading-space handling bugs are not masked.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                // The blank line terminates the event; nothing may follow it.
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.strip_prefix(' ').unwrap_or(value);
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}
695