1 use crate::codec::{Decoder, Encoder}; 2 3 use futures_core::Stream; 4 use tokio::{io::ReadBuf, net::UdpSocket}; 5 6 use bytes::{BufMut, BytesMut}; 7 use futures_sink::Sink; 8 use std::pin::Pin; 9 use std::task::{ready, Context, Poll}; 10 use std::{ 11 borrow::Borrow, 12 net::{Ipv4Addr, SocketAddr, SocketAddrV4}, 13 }; 14 use std::{io, mem::MaybeUninit}; 15 16 /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using 17 /// the `Encoder` and `Decoder` traits to encode and decode frames. 18 /// 19 /// Raw UDP sockets work with datagrams, but higher-level code usually wants to 20 /// batch these into meaningful chunks, called "frames". This method layers 21 /// framing on top of this socket by using the `Encoder` and `Decoder` traits to 22 /// handle encoding and decoding of messages frames. Note that the incoming and 23 /// outgoing frame types may be distinct. 24 /// 25 /// This function returns a *single* object that is both [`Stream`] and [`Sink`]; 26 /// grouping this into a single object is often useful for layering things which 27 /// require both read and write access to the underlying object. 28 /// 29 /// If you want to work more directly with the streams and sink, consider 30 /// calling [`split`] on the `UdpFramed` returned by this method, which will break 31 /// them into separate objects, allowing them to interact more easily. 32 /// 33 /// [`Stream`]: futures_core::Stream 34 /// [`Sink`]: futures_sink::Sink 35 /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split 36 #[must_use = "sinks do nothing unless polled"] 37 #[derive(Debug)] 38 pub struct UdpFramed<C, T = UdpSocket> { 39 socket: T, 40 codec: C, 41 rd: BytesMut, 42 wr: BytesMut, 43 out_addr: SocketAddr, 44 flushed: bool, 45 is_readable: bool, 46 current_addr: Option<SocketAddr>, 47 } 48 49 const INITIAL_RD_CAPACITY: usize = 64 * 1024; 50 const INITIAL_WR_CAPACITY: usize = 8 * 1024; 51 52 impl<C, T> Unpin for UdpFramed<C, T> {} 53 54 impl<C, T> Stream for UdpFramed<C, T> 55 where 56 T: Borrow<UdpSocket>, 57 C: Decoder, 58 { 59 type Item = Result<(C::Item, SocketAddr), C::Error>; 60 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>61 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 62 let pin = self.get_mut(); 63 64 pin.rd.reserve(INITIAL_RD_CAPACITY); 65 66 loop { 67 // Are there still bytes left in the read buffer to decode? 68 if pin.is_readable { 69 if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { 70 let current_addr = pin 71 .current_addr 72 .expect("will always be set before this line is called"); 73 74 return Poll::Ready(Some(Ok((frame, current_addr)))); 75 } 76 77 // if this line has been reached then decode has returned `None`. 78 pin.is_readable = false; 79 pin.rd.clear(); 80 } 81 82 // We're out of data. Try and fetch more data to decode 83 let addr = { 84 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a 85 // transparent wrapper around `[MaybeUninit<u8>]`. 86 let buf = unsafe { &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) }; 87 let mut read = ReadBuf::uninit(buf); 88 let ptr = read.filled().as_ptr(); 89 let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); 90 91 assert_eq!(ptr, read.filled().as_ptr()); 92 let addr = res?; 93 94 // Safety: This is guaranteed to be the number of initialized (and read) bytes due 95 // to the invariants provided by `ReadBuf::filled`. 96 unsafe { pin.rd.advance_mut(read.filled().len()) }; 97 98 addr 99 }; 100 101 pin.current_addr = Some(addr); 102 pin.is_readable = true; 103 } 104 } 105 } 106 107 impl<I, C, T> Sink<(I, SocketAddr)> for UdpFramed<C, T> 108 where 109 T: Borrow<UdpSocket>, 110 C: Encoder<I>, 111 { 112 type Error = C::Error; 113 poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>114 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 115 if !self.flushed { 116 match self.poll_flush(cx)? { 117 Poll::Ready(()) => {} 118 Poll::Pending => return Poll::Pending, 119 } 120 } 121 122 Poll::Ready(Ok(())) 123 } 124 start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error>125 fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { 126 let (frame, out_addr) = item; 127 128 let pin = self.get_mut(); 129 130 pin.codec.encode(frame, &mut pin.wr)?; 131 pin.out_addr = out_addr; 132 pin.flushed = false; 133 134 Ok(()) 135 } 136 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>137 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 138 if self.flushed { 139 return Poll::Ready(Ok(())); 140 } 141 142 let Self { 143 ref socket, 144 ref mut out_addr, 145 ref mut wr, 146 .. 147 } = *self; 148 149 let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; 150 151 let wrote_all = n == self.wr.len(); 152 self.wr.clear(); 153 self.flushed = true; 154 155 let res = if wrote_all { 156 Ok(()) 157 } else { 158 Err(io::Error::new( 159 io::ErrorKind::Other, 160 "failed to write entire datagram to socket", 161 ) 162 .into()) 163 }; 164 165 Poll::Ready(res) 166 } 167 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>168 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 169 ready!(self.poll_flush(cx))?; 170 Poll::Ready(Ok(())) 171 } 172 } 173 174 impl<C, T> UdpFramed<C, T> 175 where 176 T: Borrow<UdpSocket>, 177 { 178 /// Create a new `UdpFramed` backed by the given socket and codec. 179 /// 180 /// See struct level documentation for more details. new(socket: T, codec: C) -> UdpFramed<C, T>181 pub fn new(socket: T, codec: C) -> UdpFramed<C, T> { 182 Self { 183 socket, 184 codec, 185 out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), 186 rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), 187 wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), 188 flushed: true, 189 is_readable: false, 190 current_addr: None, 191 } 192 } 193 194 /// Returns a reference to the underlying I/O stream wrapped by `Framed`. 195 /// 196 /// # Note 197 /// 198 /// Care should be taken to not tamper with the underlying stream of data 199 /// coming in as it may corrupt the stream of frames otherwise being worked 200 /// with. get_ref(&self) -> &T201 pub fn get_ref(&self) -> &T { 202 &self.socket 203 } 204 205 /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. 206 /// 207 /// # Note 208 /// 209 /// Care should be taken to not tamper with the underlying stream of data 210 /// coming in as it may corrupt the stream of frames otherwise being worked 211 /// with. get_mut(&mut self) -> &mut T212 pub fn get_mut(&mut self) -> &mut T { 213 &mut self.socket 214 } 215 216 /// Returns a reference to the underlying codec wrapped by 217 /// `Framed`. 218 /// 219 /// Note that care should be taken to not tamper with the underlying codec 220 /// as it may corrupt the stream of frames otherwise being worked with. codec(&self) -> &C221 pub fn codec(&self) -> &C { 222 &self.codec 223 } 224 225 /// Returns a mutable reference to the underlying codec wrapped by 226 /// `UdpFramed`. 227 /// 228 /// Note that care should be taken to not tamper with the underlying codec 229 /// as it may corrupt the stream of frames otherwise being worked with. codec_mut(&mut self) -> &mut C230 pub fn codec_mut(&mut self) -> &mut C { 231 &mut self.codec 232 } 233 234 /// Returns a reference to the read buffer. read_buffer(&self) -> &BytesMut235 pub fn read_buffer(&self) -> &BytesMut { 236 &self.rd 237 } 238 239 /// Returns a mutable reference to the read buffer. read_buffer_mut(&mut self) -> &mut BytesMut240 pub fn read_buffer_mut(&mut self) -> &mut BytesMut { 241 &mut self.rd 242 } 243 244 /// Consumes the `Framed`, returning its underlying I/O stream. into_inner(self) -> T245 pub fn into_inner(self) -> T { 246 self.socket 247 } 248 } 249