1 use crate::codec::{Decoder, Encoder};
2 
3 use futures_core::Stream;
4 use tokio::{io::ReadBuf, net::UdpSocket};
5 
6 use bytes::{BufMut, BytesMut};
7 use futures_sink::Sink;
8 use std::pin::Pin;
9 use std::task::{ready, Context, Poll};
10 use std::{
11     borrow::Borrow,
12     net::{Ipv4Addr, SocketAddr, SocketAddrV4},
13 };
14 use std::{io, mem::MaybeUninit};
15 
/// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
///
/// Raw UDP sockets work with datagrams, but higher-level code usually wants to
/// batch these into meaningful chunks, called "frames". This type layers
/// framing on top of a socket by using the `Encoder` and `Decoder` traits to
/// handle encoding and decoding of message frames. Note that the incoming and
/// outgoing frame types may be distinct.
///
/// A `UdpFramed` is a *single* object that is both [`Stream`] and [`Sink`];
/// grouping this into a single object is often useful for layering things which
/// require both read and write access to the underlying object.
///
/// If you want to work more directly with the stream and sink, consider
/// calling [`split`] on the `UdpFramed`, which will break it into separate
/// objects, allowing them to interact more easily.
///
/// [`Stream`]: futures_core::Stream
/// [`Sink`]: futures_sink::Sink
/// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct UdpFramed<C, T = UdpSocket> {
    /// The underlying socket (or a value that borrows as one).
    socket: T,
    /// Codec used to encode outgoing frames and decode incoming ones.
    codec: C,
    /// Read buffer: bytes of the most recently received datagram that have not
    /// yet been consumed by the decoder.
    rd: BytesMut,
    /// Write buffer: encoded bytes waiting to be sent to `out_addr`.
    wr: BytesMut,
    /// Destination of the data currently held in `wr`.
    out_addr: SocketAddr,
    /// `false` while `wr` holds data that has not been sent yet.
    flushed: bool,
    /// `true` while `rd` may still contain data for the decoder to look at.
    is_readable: bool,
    /// Sender of the datagram currently held in `rd`; `None` until the first
    /// datagram has been received.
    current_addr: Option<SocketAddr>,
}
48 
/// Capacity (in bytes) the read buffer is created with, and the amount of
/// spare room requested from it when receiving.
const INITIAL_RD_CAPACITY: usize = 64 * 1024;
/// Capacity (in bytes) the write buffer is created with.
const INITIAL_WR_CAPACITY: usize = 8 * 1024;

// `UdpFramed` is declared `Unpin` for every `C` and `T`; the `Stream` and `Sink`
// implementations rely on this to obtain `&mut Self` from `Pin<&mut Self>`.
impl<C, T> Unpin for UdpFramed<C, T> {}
53 
54 impl<C, T> Stream for UdpFramed<C, T>
55 where
56     T: Borrow<UdpSocket>,
57     C: Decoder,
58 {
59     type Item = Result<(C::Item, SocketAddr), C::Error>;
60 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>61     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
62         let pin = self.get_mut();
63 
64         pin.rd.reserve(INITIAL_RD_CAPACITY);
65 
66         loop {
67             // Are there still bytes left in the read buffer to decode?
68             if pin.is_readable {
69                 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? {
70                     let current_addr = pin
71                         .current_addr
72                         .expect("will always be set before this line is called");
73 
74                     return Poll::Ready(Some(Ok((frame, current_addr))));
75                 }
76 
77                 // if this line has been reached then decode has returned `None`.
78                 pin.is_readable = false;
79                 pin.rd.clear();
80             }
81 
82             // We're out of data. Try and fetch more data to decode
83             let addr = {
84                 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
85                 // transparent wrapper around `[MaybeUninit<u8>]`.
86                 let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
87                 let mut read = ReadBuf::uninit(buf);
88                 let ptr = read.filled().as_ptr();
89                 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read));
90 
91                 assert_eq!(ptr, read.filled().as_ptr());
92                 let addr = res?;
93 
94                 // Safety: This is guaranteed to be the number of initialized (and read) bytes due
95                 // to the invariants provided by `ReadBuf::filled`.
96                 unsafe { pin.rd.advance_mut(read.filled().len()) };
97 
98                 addr
99             };
100 
101             pin.current_addr = Some(addr);
102             pin.is_readable = true;
103         }
104     }
105 }
106 
107 impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T>
108 where
109     T: Borrow<UdpSocket>,
110     C: Encoder<I>,
111 {
112     type Error = C::Error;
113 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>114     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115         if !self.flushed {
116             match self.poll_flush(cx)? {
117                 Poll::Ready(()) => {}
118                 Poll::Pending => return Poll::Pending,
119             }
120         }
121 
122         Poll::Ready(Ok(()))
123     }
124 
start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error>125     fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> {
126         let (frame, out_addr) = item;
127 
128         let pin = self.get_mut();
129 
130         pin.codec.encode(frame, &mut pin.wr)?;
131         pin.out_addr = out_addr;
132         pin.flushed = false;
133 
134         Ok(())
135     }
136 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>137     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138         if self.flushed {
139             return Poll::Ready(Ok(()));
140         }
141 
142         let Self {
143             ref socket,
144             ref mut out_addr,
145             ref mut wr,
146             ..
147         } = *self;
148 
149         let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?;
150 
151         let wrote_all = n == self.wr.len();
152         self.wr.clear();
153         self.flushed = true;
154 
155         let res = if wrote_all {
156             Ok(())
157         } else {
158             Err(io::Error::new(
159                 io::ErrorKind::Other,
160                 "failed to write entire datagram to socket",
161             )
162             .into())
163         };
164 
165         Poll::Ready(res)
166     }
167 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>168     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169         ready!(self.poll_flush(cx))?;
170         Poll::Ready(Ok(()))
171     }
172 }
173 
174 impl<C, T> UdpFramed<C, T>
175 where
176     T: Borrow<UdpSocket>,
177 {
178     /// Create a new `UdpFramed` backed by the given socket and codec.
179     ///
180     /// See struct level documentation for more details.
new(socket: T, codec: C) -> UdpFramed<C, T>181     pub fn new(socket: T, codec: C) -> UdpFramed<C, T> {
182         Self {
183             socket,
184             codec,
185             out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
186             rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY),
187             wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY),
188             flushed: true,
189             is_readable: false,
190             current_addr: None,
191         }
192     }
193 
194     /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
195     ///
196     /// # Note
197     ///
198     /// Care should be taken to not tamper with the underlying stream of data
199     /// coming in as it may corrupt the stream of frames otherwise being worked
200     /// with.
get_ref(&self) -> &T201     pub fn get_ref(&self) -> &T {
202         &self.socket
203     }
204 
205     /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`.
206     ///
207     /// # Note
208     ///
209     /// Care should be taken to not tamper with the underlying stream of data
210     /// coming in as it may corrupt the stream of frames otherwise being worked
211     /// with.
get_mut(&mut self) -> &mut T212     pub fn get_mut(&mut self) -> &mut T {
213         &mut self.socket
214     }
215 
216     /// Returns a reference to the underlying codec wrapped by
217     /// `Framed`.
218     ///
219     /// Note that care should be taken to not tamper with the underlying codec
220     /// as it may corrupt the stream of frames otherwise being worked with.
codec(&self) -> &C221     pub fn codec(&self) -> &C {
222         &self.codec
223     }
224 
225     /// Returns a mutable reference to the underlying codec wrapped by
226     /// `UdpFramed`.
227     ///
228     /// Note that care should be taken to not tamper with the underlying codec
229     /// as it may corrupt the stream of frames otherwise being worked with.
codec_mut(&mut self) -> &mut C230     pub fn codec_mut(&mut self) -> &mut C {
231         &mut self.codec
232     }
233 
234     /// Returns a reference to the read buffer.
read_buffer(&self) -> &BytesMut235     pub fn read_buffer(&self) -> &BytesMut {
236         &self.rd
237     }
238 
239     /// Returns a mutable reference to the read buffer.
read_buffer_mut(&mut self) -> &mut BytesMut240     pub fn read_buffer_mut(&mut self) -> &mut BytesMut {
241         &mut self.rd
242     }
243 
244     /// Consumes the `Framed`, returning its underlying I/O stream.
into_inner(self) -> T245     pub fn into_inner(self) -> T {
246         self.socket
247     }
248 }
249