//! Server-Sent Events (SSE) responses.
//!
//! # Example
//!
//! ```
//! use axum::{
//!     Router,
//!     routing::get,
//!     response::sse::{Event, KeepAlive, Sse},
//! };
//! use std::{time::Duration, convert::Infallible};
//! use tokio_stream::StreamExt as _;
//! use futures_util::stream::{self, Stream};
//!
//! let app = Router::new().route("/sse", get(sse_handler));
//!
//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
//!     // A `Stream` that repeats an event every second
//!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
//!         .map(Ok)
//!         .throttle(Duration::from_secs(1));
//!
//!     Sse::new(stream).keep_alive(KeepAlive::default())
//! }
//! # async {
//! # hyper::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
//! # };
//! ```
29
30 use crate::{
31 body::{Bytes, HttpBody},
32 BoxError,
33 };
34 use axum_core::{
35 body,
36 response::{IntoResponse, Response},
37 };
38 use bytes::{BufMut, BytesMut};
39 use futures_util::{
40 ready,
41 stream::{Stream, TryStream},
42 };
43 use pin_project_lite::pin_project;
44 use std::{
45 fmt,
46 future::Future,
47 pin::Pin,
48 task::{Context, Poll},
49 time::Duration,
50 };
51 use sync_wrapper::SyncWrapper;
52 use tokio::time::Sleep;
53
/// An SSE response.
///
/// Wraps a stream of [`Event`]s together with an optional [`KeepAlive`] configuration.
/// Build it with [`Sse::new`] and return it from a handler; it implements `IntoResponse`.
#[derive(Clone)]
#[must_use]
pub struct Sse<S> {
    // The stream of events sent to the client, in the order the stream yields them.
    stream: S,
    // Keep-alive configuration; `None` (the default from `Sse::new`) disables keep-alive.
    keep_alive: Option<KeepAlive>,
}
61
62 impl<S> Sse<S> {
63 /// Create a new [`Sse`] response that will respond with the given stream of
64 /// [`Event`]s.
65 ///
66 /// See the [module docs](self) for more details.
new(stream: S) -> Self where S: TryStream<Ok = Event> + Send + 'static, S::Error: Into<BoxError>,67 pub fn new(stream: S) -> Self
68 where
69 S: TryStream<Ok = Event> + Send + 'static,
70 S::Error: Into<BoxError>,
71 {
72 Sse {
73 stream,
74 keep_alive: None,
75 }
76 }
77
78 /// Configure the interval between keep-alive messages.
79 ///
80 /// Defaults to no keep-alive messages.
keep_alive(mut self, keep_alive: KeepAlive) -> Self81 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
82 self.keep_alive = Some(keep_alive);
83 self
84 }
85 }
86
87 impl<S> fmt::Debug for Sse<S> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("Sse")
90 .field("stream", &format_args!("{}", std::any::type_name::<S>()))
91 .field("keep_alive", &self.keep_alive)
92 .finish()
93 }
94 }
95
96 impl<S, E> IntoResponse for Sse<S>
97 where
98 S: Stream<Item = Result<Event, E>> + Send + 'static,
99 E: Into<BoxError>,
100 {
into_response(self) -> Response101 fn into_response(self) -> Response {
102 (
103 [
104 (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
105 (http::header::CACHE_CONTROL, "no-cache"),
106 ],
107 body::boxed(Body {
108 event_stream: SyncWrapper::new(self.stream),
109 keep_alive: self.keep_alive.map(KeepAliveStream::new),
110 }),
111 )
112 .into_response()
113 }
114 }
115
pin_project! {
    /// The HTTP body produced by [`Sse`]'s `IntoResponse` implementation.
    struct Body<S> {
        // The user's event stream. NOTE(review): presumably wrapped in `SyncWrapper` so the
        // body can be `Sync` while `Sse::new` only requires `S: Send` — confirm.
        #[pin]
        event_stream: SyncWrapper<S>,
        // Keep-alive timer; present only when `Sse::keep_alive` was configured.
        #[pin]
        keep_alive: Option<KeepAliveStream>,
    }
}
124
125 impl<S, E> HttpBody for Body<S>
126 where
127 S: Stream<Item = Result<Event, E>>,
128 {
129 type Data = Bytes;
130 type Error = E;
131
poll_data( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Data, Self::Error>>>132 fn poll_data(
133 self: Pin<&mut Self>,
134 cx: &mut Context<'_>,
135 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
136 let this = self.project();
137
138 match this.event_stream.get_pin_mut().poll_next(cx) {
139 Poll::Pending => {
140 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
141 keep_alive.poll_event(cx).map(|e| Some(Ok(e)))
142 } else {
143 Poll::Pending
144 }
145 }
146 Poll::Ready(Some(Ok(event))) => {
147 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
148 keep_alive.reset();
149 }
150 Poll::Ready(Some(Ok(event.finalize())))
151 }
152 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
153 Poll::Ready(None) => Poll::Ready(None),
154 }
155 }
156
poll_trailers( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>>157 fn poll_trailers(
158 self: Pin<&mut Self>,
159 _cx: &mut Context<'_>,
160 ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
161 Poll::Ready(Ok(None))
162 }
163 }
164
/// Server-sent event.
///
/// Built with the chained methods on this type (for example [`Event::data`] and
/// [`Event::event`]); each call appends its serialized field to an internal buffer.
#[derive(Debug, Default, Clone)]
#[must_use]
pub struct Event {
    // The serialized `name:value\n` lines written so far.
    buffer: BytesMut,
    // Which single-use fields have already been set (used to panic on repeats).
    flags: EventFlags,
}
172
173 impl Event {
174 /// Set the event's data data field(s) (`data:<content>`)
175 ///
176 /// Newlines in `data` will automatically be broken across `data:` fields.
177 ///
178 /// This corresponds to [`MessageEvent`'s data field].
179 ///
180 /// Note that events with an empty data field will be ignored by the browser.
181 ///
182 /// # Panics
183 ///
184 /// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
185 /// - Panics if `data` or `json_data` have already been called.
186 ///
187 /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
data<T>(mut self, data: T) -> Event where T: AsRef<str>,188 pub fn data<T>(mut self, data: T) -> Event
189 where
190 T: AsRef<str>,
191 {
192 if self.flags.contains(EventFlags::HAS_DATA) {
193 panic!("Called `EventBuilder::data` multiple times");
194 }
195
196 for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
197 self.field("data", line);
198 }
199
200 self.flags.insert(EventFlags::HAS_DATA);
201
202 self
203 }
204
205 /// Set the event's data field to a value serialized as unformatted JSON (`data:<content>`).
206 ///
207 /// This corresponds to [`MessageEvent`'s data field].
208 ///
209 /// # Panics
210 ///
211 /// Panics if `data` or `json_data` have already been called.
212 ///
213 /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
214 #[cfg(feature = "json")]
json_data<T>(mut self, data: T) -> serde_json::Result<Event> where T: serde::Serialize,215 pub fn json_data<T>(mut self, data: T) -> serde_json::Result<Event>
216 where
217 T: serde::Serialize,
218 {
219 if self.flags.contains(EventFlags::HAS_DATA) {
220 panic!("Called `EventBuilder::json_data` multiple times");
221 }
222
223 self.buffer.extend_from_slice(b"data:");
224 serde_json::to_writer((&mut self.buffer).writer(), &data)?;
225 self.buffer.put_u8(b'\n');
226
227 self.flags.insert(EventFlags::HAS_DATA);
228
229 Ok(self)
230 }
231
232 /// Set the event's comment field (`:<comment-text>`).
233 ///
234 /// This field will be ignored by most SSE clients.
235 ///
236 /// Unlike other functions, this function can be called multiple times to add many comments.
237 ///
238 /// # Panics
239 ///
240 /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
241 /// comments.
comment<T>(mut self, comment: T) -> Event where T: AsRef<str>,242 pub fn comment<T>(mut self, comment: T) -> Event
243 where
244 T: AsRef<str>,
245 {
246 self.field("", comment.as_ref());
247 self
248 }
249
250 /// Set the event's name field (`event:<event-name>`).
251 ///
252 /// This corresponds to the `type` parameter given when calling `addEventListener` on an
253 /// [`EventSource`]. For example, `.event("update")` should correspond to
254 /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
255 /// [`message` event] instead.
256 ///
257 /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
258 /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
259 ///
260 /// # Panics
261 ///
262 /// - Panics if `event` contains any newlines or carriage returns.
263 /// - Panics if this function has already been called on this event.
event<T>(mut self, event: T) -> Event where T: AsRef<str>,264 pub fn event<T>(mut self, event: T) -> Event
265 where
266 T: AsRef<str>,
267 {
268 if self.flags.contains(EventFlags::HAS_EVENT) {
269 panic!("Called `EventBuilder::event` multiple times");
270 }
271 self.flags.insert(EventFlags::HAS_EVENT);
272
273 self.field("event", event.as_ref());
274
275 self
276 }
277
278 /// Set the event's retry timeout field (`retry:<timeout>`).
279 ///
280 /// This sets how long clients will wait before reconnecting if they are disconnected from the
281 /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
282 /// wish, such as if they implement exponential backoff.
283 ///
284 /// # Panics
285 ///
286 /// Panics if this function has already been called on this event.
retry(mut self, duration: Duration) -> Event287 pub fn retry(mut self, duration: Duration) -> Event {
288 if self.flags.contains(EventFlags::HAS_RETRY) {
289 panic!("Called `EventBuilder::retry` multiple times");
290 }
291 self.flags.insert(EventFlags::HAS_RETRY);
292
293 self.buffer.extend_from_slice(b"retry:");
294
295 let secs = duration.as_secs();
296 let millis = duration.subsec_millis();
297
298 if secs > 0 {
299 // format seconds
300 self.buffer
301 .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
302
303 // pad milliseconds
304 if millis < 10 {
305 self.buffer.extend_from_slice(b"00");
306 } else if millis < 100 {
307 self.buffer.extend_from_slice(b"0");
308 }
309 }
310
311 // format milliseconds
312 self.buffer
313 .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
314
315 self.buffer.put_u8(b'\n');
316
317 self
318 }
319
320 /// Set the event's identifier field (`id:<identifier>`).
321 ///
322 /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
323 /// the browser will set that field to the last known message ID, starting with the empty
324 /// string.
325 ///
326 /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
327 ///
328 /// # Panics
329 ///
330 /// - Panics if `id` contains any newlines, carriage returns or null characters.
331 /// - Panics if this function has already been called on this event.
id<T>(mut self, id: T) -> Event where T: AsRef<str>,332 pub fn id<T>(mut self, id: T) -> Event
333 where
334 T: AsRef<str>,
335 {
336 if self.flags.contains(EventFlags::HAS_ID) {
337 panic!("Called `EventBuilder::id` multiple times");
338 }
339 self.flags.insert(EventFlags::HAS_ID);
340
341 let id = id.as_ref().as_bytes();
342 assert_eq!(
343 memchr::memchr(b'\0', id),
344 None,
345 "Event ID cannot contain null characters",
346 );
347
348 self.field("id", id);
349 self
350 }
351
field(&mut self, name: &str, value: impl AsRef<[u8]>)352 fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
353 let value = value.as_ref();
354 assert_eq!(
355 memchr::memchr2(b'\r', b'\n', value),
356 None,
357 "SSE field value cannot contain newlines or carriage returns",
358 );
359 self.buffer.extend_from_slice(name.as_bytes());
360 self.buffer.put_u8(b':');
361 // Prevent values that start with spaces having that space stripped
362 if value.starts_with(b" ") {
363 self.buffer.put_u8(b' ');
364 }
365 self.buffer.extend_from_slice(value);
366 self.buffer.put_u8(b'\n');
367 }
368
finalize(mut self) -> Bytes369 fn finalize(mut self) -> Bytes {
370 self.buffer.put_u8(b'\n');
371 self.buffer.freeze()
372 }
373 }
374
bitflags::bitflags! {
    /// Records which single-use [`Event`] fields have already been set, so that setting one of
    /// them a second time can panic.
    #[derive(Default)]
    struct EventFlags: u8 {
        // Set by `Event::data` / `Event::json_data`.
        const HAS_DATA = 0b0001;
        // Set by `Event::event`.
        const HAS_EVENT = 0b0010;
        // Set by `Event::retry`.
        const HAS_RETRY = 0b0100;
        // Set by `Event::id`.
        const HAS_ID = 0b1000;
    }
}
384
/// Configuration for the keep-alive messages sent by [`Sse`]: how long the event stream may
/// stay idle before one is sent, and what each message contains.
#[derive(Debug, Clone)]
#[must_use]
pub struct KeepAlive {
    // The already-serialized keep-alive message, including its terminating blank line.
    event: Bytes,
    // Idle time after which a keep-alive message is sent.
    max_interval: Duration,
}
393
394 impl KeepAlive {
395 /// Create a new `KeepAlive`.
new() -> Self396 pub fn new() -> Self {
397 Self {
398 event: Bytes::from_static(b":\n\n"),
399 max_interval: Duration::from_secs(15),
400 }
401 }
402
403 /// Customize the interval between keep-alive messages.
404 ///
405 /// Default is 15 seconds.
interval(mut self, time: Duration) -> Self406 pub fn interval(mut self, time: Duration) -> Self {
407 self.max_interval = time;
408 self
409 }
410
411 /// Customize the text of the keep-alive message.
412 ///
413 /// Default is an empty comment.
414 ///
415 /// # Panics
416 ///
417 /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
418 /// comments.
text<I>(self, text: I) -> Self where I: AsRef<str>,419 pub fn text<I>(self, text: I) -> Self
420 where
421 I: AsRef<str>,
422 {
423 self.event(Event::default().comment(text))
424 }
425
426 /// Customize the event of the keep-alive message.
427 ///
428 /// Default is an empty comment.
429 ///
430 /// # Panics
431 ///
432 /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
433 /// comments.
event(mut self, event: Event) -> Self434 pub fn event(mut self, event: Event) -> Self {
435 self.event = event.finalize();
436 self
437 }
438 }
439
440 impl Default for KeepAlive {
default() -> Self441 fn default() -> Self {
442 Self::new()
443 }
444 }
445
pin_project! {
    /// Timer state deciding when the next keep-alive message is due.
    #[derive(Debug)]
    struct KeepAliveStream {
        // The configuration: the pre-serialized message and the idle interval.
        keep_alive: KeepAlive,
        // Fires `keep_alive.max_interval` after creation or after the most recent reset.
        #[pin]
        alive_timer: Sleep,
    }
}
454
455 impl KeepAliveStream {
new(keep_alive: KeepAlive) -> Self456 fn new(keep_alive: KeepAlive) -> Self {
457 Self {
458 alive_timer: tokio::time::sleep(keep_alive.max_interval),
459 keep_alive,
460 }
461 }
462
reset(self: Pin<&mut Self>)463 fn reset(self: Pin<&mut Self>) {
464 let this = self.project();
465 this.alive_timer
466 .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
467 }
468
poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes>469 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
470 let this = self.as_mut().project();
471
472 ready!(this.alive_timer.poll(cx));
473
474 let event = this.keep_alive.event.clone();
475
476 self.reset();
477
478 Poll::Ready(event)
479 }
480 }
481
memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_>482 fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
483 MemchrSplit {
484 needle,
485 haystack: Some(haystack),
486 }
487 }
488
/// Iterator returned by `memchr_split`.
struct MemchrSplit<'a> {
    // The byte to split on.
    needle: u8,
    // The not-yet-split remainder; `None` once the final piece has been yielded.
    haystack: Option<&'a [u8]>,
}
493
494 impl<'a> Iterator for MemchrSplit<'a> {
495 type Item = &'a [u8];
next(&mut self) -> Option<Self::Item>496 fn next(&mut self) -> Option<Self::Item> {
497 let haystack = self.haystack?;
498 if let Some(pos) = memchr::memchr(self.needle, haystack) {
499 let (front, back) = haystack.split_at(pos);
500 self.haystack = Some(&back[1..]);
501 Some(front)
502 } else {
503 self.haystack.take()
504 }
505 }
506 }
507
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data:\tfoobar\n\n");

        // `Event::field` emits an extra space in front of values that start with one, so the
        // value's own leading space is preserved.
        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    #[test]
    fn multi_line_data_is_split_into_fields() {
        let event = Event::default().data("one\ntwo");
        assert_eq!(&*event.finalize(), b"data:one\ndata:two\n\n");
    }

    #[test]
    fn retry_is_formatted_in_milliseconds() {
        let cases = [
            (Duration::from_millis(0), "retry:0\n\n"),
            (Duration::from_micros(1_500), "retry:1\n\n"),
            (Duration::from_millis(250), "retry:250\n\n"),
            (Duration::from_millis(1_005), "retry:1005\n\n"),
            (Duration::from_millis(1_050), "retry:1050\n\n"),
            (Duration::from_secs(30), "retry:30000\n\n"),
        ];

        for (duration, expected) in cases {
            let event = Event::default().retry(duration);
            assert_eq!(&*event.finalize(), expected.as_bytes(), "{:?}", duration);
        }
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").send().await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(event_fields.get("comment").is_none());

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(event_fields.get("comment").is_none());

        assert!(stream.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").send().await;

        for _ in 0..5 {
            // first message should be an event
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").send().await;

        // first message should be an event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(stream.chunk_text().await.is_none());
    }

    /// Parse a single SSE event chunk into a `field name -> trimmed value` map. Comment lines
    /// (empty field name) are stored under the key `"comment"`.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines();
        while let Some(line) = lines.next() {
            // The blank line that terminates the event must be the last line of the chunk.
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}
695