1 use std::cmp;
2 use std::fmt;
3 #[cfg(all(feature = "server", feature = "runtime"))]
4 use std::future::Future;
5 use std::io::{self, IoSlice};
6 use std::marker::Unpin;
7 use std::mem::MaybeUninit;
8 use std::pin::Pin;
9 use std::task::{Context, Poll};
10 #[cfg(all(feature = "server", feature = "runtime"))]
11 use std::time::Duration;
12 
13 use bytes::{Buf, BufMut, Bytes, BytesMut};
14 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15 #[cfg(all(feature = "server", feature = "runtime"))]
16 use tokio::time::Instant;
17 use tracing::{debug, trace};
18 
19 use super::{Http1Transaction, ParseContext, ParsedMessage};
20 use crate::common::buf::BufList;
21 
/// Number of bytes reserved in the read buffer before the first read from IO.
pub(crate) const INIT_BUFFER_SIZE: usize = 8192;

/// Smallest value accepted as a maximum buffer size (`set_max_buf_size`
/// asserts against it); it equals the initial read allocation.
pub(crate) const MINIMUM_MAX_BUFFER_SIZE: usize = INIT_BUFFER_SIZE;

/// Default upper limit for the read buffer. When a message is still
/// incomplete once the buffer has grown this large, a `TooLarge` error is
/// raised.
// Note: if this changes, update server::conn::Http::max_buf_size docs.
pub(crate) const DEFAULT_MAX_BUFFER_SIZE: usize = 8192 + 100 * 4096;

/// Upper bound on how many separate `Buf`s may sit in the write queue before
/// a flush is forced. Only relevant when the write strategy queues buffers.
///
/// This is a ceiling, not a target: flushes may well happen before it is hit.
const MAX_BUF_LIST_BUFFERS: usize = 16;
39 
40 pub(crate) struct Buffered<T, B> {
41     flush_pipeline: bool,
42     io: T,
43     read_blocked: bool,
44     read_buf: BytesMut,
45     read_buf_strategy: ReadStrategy,
46     write_buf: WriteBuf<B>,
47 }
48 
49 impl<T, B> fmt::Debug for Buffered<T, B>
50 where
51     B: Buf,
52 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result53     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54         f.debug_struct("Buffered")
55             .field("read_buf", &self.read_buf)
56             .field("write_buf", &self.write_buf)
57             .finish()
58     }
59 }
60 
impl<T, B> Buffered<T, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    B: Buf,
{
    /// Wraps `io` with an empty read buffer and a fresh write buffer.
    ///
    /// The write strategy is `Queue` when `io.is_write_vectored()` reports
    /// vectored-write support, and `Flatten` otherwise.
    pub(crate) fn new(io: T) -> Buffered<T, B> {
        let strategy = if io.is_write_vectored() {
            WriteStrategy::Queue
        } else {
            WriteStrategy::Flatten
        };
        let write_buf = WriteBuf::new(strategy);
        Buffered {
            flush_pipeline: false,
            io,
            read_blocked: false,
            // Nothing is allocated up front; `poll_read_from_io` reserves
            // capacity right before the first read.
            read_buf: BytesMut::with_capacity(0),
            read_buf_strategy: ReadStrategy::default(),
            write_buf,
        }
    }

    /// Enables or disables pipeline flushing (see `poll_flush`).
    ///
    /// Enabling it also switches the write buffer to the `Flatten` strategy.
    /// Must be called while the write buffer is empty (debug-asserted).
    #[cfg(feature = "server")]
    pub(crate) fn set_flush_pipeline(&mut self, enabled: bool) {
        debug_assert!(!self.write_buf.has_remaining());
        self.flush_pipeline = enabled;
        if enabled {
            self.set_write_strategy_flatten();
        }
    }

    /// Sets the size limit of both the read buffer and the write buffer.
    ///
    /// The read strategy is replaced by a fresh adaptive strategy capped at
    /// `max`.
    ///
    /// # Panics
    ///
    /// Panics if `max` is smaller than `MINIMUM_MAX_BUFFER_SIZE`.
    pub(crate) fn set_max_buf_size(&mut self, max: usize) {
        assert!(
            max >= MINIMUM_MAX_BUFFER_SIZE,
            "The max_buf_size cannot be smaller than {}.",
            MINIMUM_MAX_BUFFER_SIZE,
        );
        self.read_buf_strategy = ReadStrategy::with_max(max);
        self.write_buf.max_buf_size = max;
    }

    /// Makes every read reserve exactly `sz` bytes of spare capacity.
    ///
    /// The `Exact` strategy also reports `sz` as its `max`. `parse` therefore
    /// fails with `TooLarge` once `sz` unparsed bytes are buffered.
    #[cfg(feature = "client")]
    pub(crate) fn set_read_buf_exact_size(&mut self, sz: usize) {
        self.read_buf_strategy = ReadStrategy::Exact(sz);
    }

    /// Switches the write buffer to copying all data into one buffer.
    pub(crate) fn set_write_strategy_flatten(&mut self) {
        // this should always be called only at construction time,
        // so this assert is here to catch myself
        debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
        self.write_buf.set_strategy(WriteStrategy::Flatten);
    }

    /// Switches the write buffer to keeping user buffers in a queue.
    pub(crate) fn set_write_strategy_queue(&mut self) {
        // this should always be called only at construction time,
        // so this assert is here to catch myself
        debug_assert!(self.write_buf.queue.bufs_cnt() == 0);
        self.write_buf.set_strategy(WriteStrategy::Queue);
    }

    /// The bytes read from IO that have not been consumed yet.
    pub(crate) fn read_buf(&self) -> &[u8] {
        self.read_buf.as_ref()
    }

    /// Test-only mutable access to the read buffer.
    #[cfg(test)]
    #[cfg(feature = "nightly")]
    pub(super) fn read_buf_mut(&mut self) -> &mut BytesMut {
        &mut self.read_buf
    }

    /// Return the "allocated" available space, not the potential space
    /// that could be allocated in the future.
    fn read_buf_remaining_mut(&self) -> usize {
        self.read_buf.capacity() - self.read_buf.len()
    }

    /// Return whether we can append to the headers buffer.
    ///
    /// Reasons we can't:
    /// - The write buf is in queue mode, and some of the past body is still
    ///   needing to be flushed.
    pub(crate) fn can_headers_buf(&self) -> bool {
        !self.write_buf.queue.has_remaining()
    }

    /// Returns the headers byte buffer so the caller can append to it.
    ///
    /// `WriteBuf::headers_mut` debug-asserts that nothing is queued.
    pub(crate) fn headers_buf(&mut self) -> &mut Vec<u8> {
        let buf = self.write_buf.headers_mut();
        &mut buf.bytes
    }

    /// Direct access to the write buffer.
    pub(super) fn write_buf(&mut self) -> &mut WriteBuf<B> {
        &mut self.write_buf
    }

    /// Buffers `buf` for writing according to the current write strategy.
    pub(crate) fn buffer<BB: Buf + Into<B>>(&mut self, buf: BB) {
        self.write_buf.buffer(buf)
    }

    /// Whether more data may be buffered before flushing.
    ///
    /// This is always true when pipeline flushing is enabled; otherwise
    /// `WriteBuf::can_buffer` decides.
    pub(crate) fn can_buffer(&self) -> bool {
        self.flush_pipeline || self.write_buf.can_buffer()
    }

    /// Drops any leading `\r` / `\n` bytes from the read buffer.
    pub(crate) fn consume_leading_lines(&mut self) {
        if !self.read_buf.is_empty() {
            let mut i = 0;
            while i < self.read_buf.len() {
                match self.read_buf[i] {
                    b'\r' | b'\n' => i += 1,
                    _ => break,
                }
            }
            self.read_buf.advance(i);
        }
    }

    /// Tries to parse a message head of role `S` from the read buffer.
    ///
    /// Whenever the buffered bytes are not yet a complete head, it reads more
    /// from `io` and tries again.
    ///
    /// Errors returned:
    /// - any error from `parse_headers`;
    /// - `TooLarge` once the buffer holds `read_buf_strategy.max()` bytes or
    ///   more without a complete head;
    /// - a header-read timeout (server + runtime builds only), when the
    ///   running timer has fired;
    /// - IO errors from the read, wrapped with `Error::new_io`;
    /// - `Incomplete` if the read returns 0 bytes (EOF) first.
    pub(super) fn parse<S>(
        &mut self,
        cx: &mut Context<'_>,
        parse_ctx: ParseContext<'_>,
    ) -> Poll<crate::Result<ParsedMessage<S::Incoming>>>
    where
        S: Http1Transaction,
    {
        loop {
            // `parse_headers` takes its context by value, so a new
            // `ParseContext` is built every iteration by re-borrowing the
            // caller's fields.
            match super::role::parse_headers::<S>(
                &mut self.read_buf,
                ParseContext {
                    cached_headers: parse_ctx.cached_headers,
                    req_method: parse_ctx.req_method,
                    h1_parser_config: parse_ctx.h1_parser_config.clone(),
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    h1_header_read_timeout: parse_ctx.h1_header_read_timeout,
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    h1_header_read_timeout_fut: parse_ctx.h1_header_read_timeout_fut,
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    h1_header_read_timeout_running: parse_ctx.h1_header_read_timeout_running,
                    preserve_header_case: parse_ctx.preserve_header_case,
                    #[cfg(feature = "ffi")]
                    preserve_header_order: parse_ctx.preserve_header_order,
                    h09_responses: parse_ctx.h09_responses,
                    #[cfg(feature = "ffi")]
                    on_informational: parse_ctx.on_informational,
                    #[cfg(feature = "ffi")]
                    raw_headers: parse_ctx.raw_headers,
                },
            )? {
                Some(msg) => {
                    debug!("parsed {} headers", msg.head.headers.len());

                    #[cfg(all(feature = "server", feature = "runtime"))]
                    {
                        *parse_ctx.h1_header_read_timeout_running = false;

                        if let Some(h1_header_read_timeout_fut) =
                            parse_ctx.h1_header_read_timeout_fut
                        {
                            // The head is parsed, so push the deadline about
                            // 30 days out; the timer then won't wake this task
                            // when the original timeout would have expired.
                            h1_header_read_timeout_fut
                                .as_mut()
                                .reset(Instant::now() + Duration::from_secs(30 * 24 * 60 * 60));
                        }
                    }
                    return Poll::Ready(Ok(msg));
                }
                None => {
                    // Incomplete head: give up once the buffer reached its
                    // configured limit, otherwise fall through and read more.
                    let max = self.read_buf_strategy.max();
                    if self.read_buf.len() >= max {
                        debug!("max_buf_size ({}) reached, closing", max);
                        return Poll::Ready(Err(crate::Error::new_too_large()));
                    }

                    // Polling the timer here registers `cx` for its wakeup,
                    // so a client that stalls mid-head still gets timed out.
                    #[cfg(all(feature = "server", feature = "runtime"))]
                    if *parse_ctx.h1_header_read_timeout_running {
                        if let Some(h1_header_read_timeout_fut) =
                            parse_ctx.h1_header_read_timeout_fut
                        {
                            if Pin::new(h1_header_read_timeout_fut).poll(cx).is_ready() {
                                *parse_ctx.h1_header_read_timeout_running = false;

                                tracing::warn!("read header from client timeout");
                                return Poll::Ready(Err(crate::Error::new_header_timeout()));
                            }
                        }
                    }
                }
            }
            // A 0-byte read means the peer closed before sending a full head.
            if ready!(self.poll_read_from_io(cx)).map_err(crate::Error::new_io)? == 0 {
                trace!("parse eof");
                return Poll::Ready(Err(crate::Error::new_incomplete()));
            }
        }
    }

    /// Performs one read from `io` into the spare capacity of `read_buf`.
    ///
    /// If the current spare capacity is below `read_buf_strategy.next()`, it
    /// first reserves that many bytes. On success it returns the number of
    /// bytes read and feeds that count back into the read strategy. `parse`
    /// treats `Ok(0)` as end of stream.
    ///
    /// `read_blocked` is cleared on entry and set again only if the read
    /// returns `Poll::Pending`.
    pub(crate) fn poll_read_from_io(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        self.read_blocked = false;
        let next = self.read_buf_strategy.next();
        if self.read_buf_remaining_mut() < next {
            self.read_buf.reserve(next);
        }

        let dst = self.read_buf.chunk_mut();
        // SAFETY: `dst` is the spare, possibly uninitialized capacity of
        // `read_buf`. It is only reinterpreted as `[MaybeUninit<u8>]`, which
        // never claims the bytes are initialized. `ReadBuf::uninit` then
        // tracks which bytes the reader fills.
        // NOTE(review): this relies on `chunk_mut()`'s return type (an
        // `UninitSlice` in bytes 1.x) being layout-compatible with
        // `[MaybeUninit<u8>]`. Confirm against the `bytes` version in use.
        let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
        let mut buf = ReadBuf::uninit(dst);
        match Pin::new(&mut self.io).poll_read(cx, &mut buf) {
            Poll::Ready(Ok(_)) => {
                let n = buf.filled().len();
                trace!("received {} bytes", n);
                unsafe {
                    // Safety: we just read that many bytes into the
                    // uninitialized part of the buffer, so this is okay.
                    // @tokio pls give me back `poll_read_buf` thanks
                    self.read_buf.advance_mut(n);
                }
                self.read_buf_strategy.record(n);
                Poll::Ready(Ok(n))
            }
            Poll::Pending => {
                self.read_blocked = true;
                Poll::Pending
            }
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
        }
    }

    /// Consumes the wrapper and returns the transport plus any unconsumed
    /// read bytes. Bytes still sitting in the write buffer are dropped.
    pub(crate) fn into_inner(self) -> (T, Bytes) {
        (self.io, self.read_buf.freeze())
    }

    /// Mutable access to the wrapped transport.
    pub(crate) fn io_mut(&mut self) -> &mut T {
        &mut self.io
    }

    /// Whether the last `poll_read_from_io` call returned `Poll::Pending`.
    pub(crate) fn is_read_blocked(&self) -> bool {
        self.read_blocked
    }

    /// Writes all buffered output to `io`, then flushes `io`.
    ///
    /// If pipeline flushing is enabled and unread input is buffered, it
    /// returns `Ready(Ok(()))` immediately without writing anything. With
    /// the `Flatten` strategy it delegates to `poll_flush_flattened`.
    /// Otherwise it issues vectored writes of at most `MAX_WRITEV_BUFS`
    /// slices until the buffer is empty.
    ///
    /// Returns an `io::ErrorKind::WriteZero` error if a write accepts 0 bytes
    /// while data remains.
    pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if self.flush_pipeline && !self.read_buf.is_empty() {
            Poll::Ready(Ok(()))
        } else if self.write_buf.remaining() == 0 {
            Pin::new(&mut self.io).poll_flush(cx)
        } else {
            if let WriteStrategy::Flatten = self.write_buf.strategy {
                return self.poll_flush_flattened(cx);
            }

            const MAX_WRITEV_BUFS: usize = 64;
            loop {
                let n = {
                    let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS];
                    let len = self.write_buf.chunks_vectored(&mut iovs);
                    ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))?
                };
                // TODO(eliza): we have to do this manually because
                // `poll_write_buf` doesn't exist in Tokio 0.3 yet...when
                // `poll_write_buf` comes back, the manual advance will need to leave!
                self.write_buf.advance(n);
                debug!("flushed {} bytes", n);
                if self.write_buf.remaining() == 0 {
                    break;
                } else if n == 0 {
                    trace!(
                        "write returned zero, but {} bytes remaining",
                        self.write_buf.remaining()
                    );
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
            }
            Pin::new(&mut self.io).poll_flush(cx)
        }
    }

    /// Specialized version of `flush` when strategy is Flatten.
    ///
    /// Since all buffered bytes are flattened into the single headers buffer,
    /// that skips some bookkeeping around using multiple buffers.
    ///
    /// It writes the headers buffer with plain `poll_write` calls. Once the
    /// buffer is drained it is reset and `io` is flushed. A write that
    /// accepts 0 bytes while data remains yields an
    /// `io::ErrorKind::WriteZero` error.
    fn poll_flush_flattened(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        loop {
            let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?;
            debug!("flushed {} bytes", n);
            self.write_buf.headers.advance(n);
            if self.write_buf.headers.remaining() == 0 {
                self.write_buf.headers.reset();
                break;
            } else if n == 0 {
                trace!(
                    "write returned zero, but {} bytes remaining",
                    self.write_buf.remaining()
                );
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
            }
        }
        Pin::new(&mut self.io).poll_flush(cx)
    }

    /// Test-only helper: `poll_flush` wrapped as a future.
    #[cfg(test)]
    fn flush<'a>(&'a mut self) -> impl std::future::Future<Output = io::Result<()>> + 'a {
        futures_util::future::poll_fn(move |cx| self.poll_flush(cx))
    }
}
361 
362 // The `B` is a `Buf`, we never project a pin to it
363 impl<T: Unpin, B> Unpin for Buffered<T, B> {}
364 
365 // TODO: This trait is old... at least rename to PollBytes or something...
366 pub(crate) trait MemRead {
read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>367     fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>;
368 }
369 
370 impl<T, B> MemRead for Buffered<T, B>
371 where
372     T: AsyncRead + AsyncWrite + Unpin,
373     B: Buf,
374 {
read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>>375     fn read_mem(&mut self, cx: &mut Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
376         if !self.read_buf.is_empty() {
377             let n = std::cmp::min(len, self.read_buf.len());
378             Poll::Ready(Ok(self.read_buf.split_to(n).freeze()))
379         } else {
380             let n = ready!(self.poll_read_from_io(cx))?;
381             Poll::Ready(Ok(self.read_buf.split_to(::std::cmp::min(len, n)).freeze()))
382         }
383     }
384 }
385 
/// Decides how much spare read-buffer capacity to reserve before each read,
/// and the read buffer's size limit.
#[derive(Clone, Copy, Debug)]
enum ReadStrategy {
    /// The reservation grows and shrinks with the observed read sizes.
    Adaptive {
        /// Set after one read that fell below the shrink threshold; a second
        /// such read in a row actually shrinks `next`.
        decrease_now: bool,
        /// Bytes to reserve before the next read.
        next: usize,
        /// Upper bound for `next`; also reported as the buffer size limit.
        max: usize,
    },
    /// Always reserve exactly this many bytes; also reported as the limit.
    #[cfg(feature = "client")]
    Exact(usize),
}
396 
397 impl ReadStrategy {
with_max(max: usize) -> ReadStrategy398     fn with_max(max: usize) -> ReadStrategy {
399         ReadStrategy::Adaptive {
400             decrease_now: false,
401             next: INIT_BUFFER_SIZE,
402             max,
403         }
404     }
405 
next(&self) -> usize406     fn next(&self) -> usize {
407         match *self {
408             ReadStrategy::Adaptive { next, .. } => next,
409             #[cfg(feature = "client")]
410             ReadStrategy::Exact(exact) => exact,
411         }
412     }
413 
max(&self) -> usize414     fn max(&self) -> usize {
415         match *self {
416             ReadStrategy::Adaptive { max, .. } => max,
417             #[cfg(feature = "client")]
418             ReadStrategy::Exact(exact) => exact,
419         }
420     }
421 
record(&mut self, bytes_read: usize)422     fn record(&mut self, bytes_read: usize) {
423         match *self {
424             ReadStrategy::Adaptive {
425                 ref mut decrease_now,
426                 ref mut next,
427                 max,
428                 ..
429             } => {
430                 if bytes_read >= *next {
431                     *next = cmp::min(incr_power_of_two(*next), max);
432                     *decrease_now = false;
433                 } else {
434                     let decr_to = prev_power_of_two(*next);
435                     if bytes_read < decr_to {
436                         if *decrease_now {
437                             *next = cmp::max(decr_to, INIT_BUFFER_SIZE);
438                             *decrease_now = false;
439                         } else {
440                             // Decreasing is a two "record" process.
441                             *decrease_now = true;
442                         }
443                     } else {
444                         // A read within the current range should cancel
445                         // a potential decrease, since we just saw proof
446                         // that we still need this size.
447                         *decrease_now = false;
448                     }
449                 }
450             }
451             #[cfg(feature = "client")]
452             ReadStrategy::Exact(_) => (),
453         }
454     }
455 }
456 
/// Doubles `n`, clamping to `usize::MAX` instead of overflowing.
fn incr_power_of_two(n: usize) -> usize {
    n.checked_mul(2).unwrap_or(usize::MAX)
}
460 
/// Returns half of the largest power of two that is `<= n`.
///
/// For a power of two this is simply `n / 2`. For any other value, the bits
/// below the highest set bit are ignored first, e.g. `12288 -> 4096` and
/// `24576 -> 8192`.
///
/// `n` must be at least 4. For smaller values the shift amount below reaches
/// the bit width of `usize`, which overflows.
fn prev_power_of_two(n: usize) -> usize {
    debug_assert!(n >= 4);
    // `usize::MAX >> (lz + 2)` keeps exactly (highest_bit_index - 1) low ones;
    // adding one turns that run of ones into the power of two just above it.
    let shift = n.leading_zeros() + 2;
    let low_ones = ::std::usize::MAX >> shift;
    low_ones + 1
}
467 
468 impl Default for ReadStrategy {
default() -> ReadStrategy469     fn default() -> ReadStrategy {
470         ReadStrategy::with_max(DEFAULT_MAX_BUFFER_SIZE)
471     }
472 }
473 
/// A byte container paired with a read position. Bytes before `pos` count as
/// already consumed.
#[derive(Clone)]
pub(crate) struct Cursor<T> {
    /// Underlying storage.
    bytes: T,
    /// Index of the first unconsumed byte in `bytes`.
    pos: usize,
}

impl<T: AsRef<[u8]>> Cursor<T> {
    /// Wraps `bytes` with the read position at the very start.
    #[inline]
    pub(crate) fn new(bytes: T) -> Cursor<T> {
        Self { pos: 0, bytes }
    }
}
486 
487 impl Cursor<Vec<u8>> {
488     /// If we've advanced the position a bit in this cursor, and wish to
489     /// extend the underlying vector, we may wish to unshift the "read" bytes
490     /// off, and move everything else over.
maybe_unshift(&mut self, additional: usize)491     fn maybe_unshift(&mut self, additional: usize) {
492         if self.pos == 0 {
493             // nothing to do
494             return;
495         }
496 
497         if self.bytes.capacity() - self.bytes.len() >= additional {
498             // there's room!
499             return;
500         }
501 
502         self.bytes.drain(0..self.pos);
503         self.pos = 0;
504     }
505 
reset(&mut self)506     fn reset(&mut self) {
507         self.pos = 0;
508         self.bytes.clear();
509     }
510 }
511 
512 impl<T: AsRef<[u8]>> fmt::Debug for Cursor<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result513     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
514         f.debug_struct("Cursor")
515             .field("pos", &self.pos)
516             .field("len", &self.bytes.as_ref().len())
517             .finish()
518     }
519 }
520 
521 impl<T: AsRef<[u8]>> Buf for Cursor<T> {
522     #[inline]
remaining(&self) -> usize523     fn remaining(&self) -> usize {
524         self.bytes.as_ref().len() - self.pos
525     }
526 
527     #[inline]
chunk(&self) -> &[u8]528     fn chunk(&self) -> &[u8] {
529         &self.bytes.as_ref()[self.pos..]
530     }
531 
532     #[inline]
advance(&mut self, cnt: usize)533     fn advance(&mut self, cnt: usize) {
534         debug_assert!(self.pos + cnt <= self.bytes.as_ref().len());
535         self.pos += cnt;
536     }
537 }
538 
539 // an internal buffer to collect writes before flushes
540 pub(super) struct WriteBuf<B> {
541     /// Re-usable buffer that holds message headers
542     headers: Cursor<Vec<u8>>,
543     max_buf_size: usize,
544     /// Deque of user buffers if strategy is Queue
545     queue: BufList<B>,
546     strategy: WriteStrategy,
547 }
548 
549 impl<B: Buf> WriteBuf<B> {
new(strategy: WriteStrategy) -> WriteBuf<B>550     fn new(strategy: WriteStrategy) -> WriteBuf<B> {
551         WriteBuf {
552             headers: Cursor::new(Vec::with_capacity(INIT_BUFFER_SIZE)),
553             max_buf_size: DEFAULT_MAX_BUFFER_SIZE,
554             queue: BufList::new(),
555             strategy,
556         }
557     }
558 }
559 
560 impl<B> WriteBuf<B>
561 where
562     B: Buf,
563 {
set_strategy(&mut self, strategy: WriteStrategy)564     fn set_strategy(&mut self, strategy: WriteStrategy) {
565         self.strategy = strategy;
566     }
567 
buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB)568     pub(super) fn buffer<BB: Buf + Into<B>>(&mut self, mut buf: BB) {
569         debug_assert!(buf.has_remaining());
570         match self.strategy {
571             WriteStrategy::Flatten => {
572                 let head = self.headers_mut();
573 
574                 head.maybe_unshift(buf.remaining());
575                 trace!(
576                     self.len = head.remaining(),
577                     buf.len = buf.remaining(),
578                     "buffer.flatten"
579                 );
580                 //perf: This is a little faster than <Vec as BufMut>>::put,
581                 //but accomplishes the same result.
582                 loop {
583                     let adv = {
584                         let slice = buf.chunk();
585                         if slice.is_empty() {
586                             return;
587                         }
588                         head.bytes.extend_from_slice(slice);
589                         slice.len()
590                     };
591                     buf.advance(adv);
592                 }
593             }
594             WriteStrategy::Queue => {
595                 trace!(
596                     self.len = self.remaining(),
597                     buf.len = buf.remaining(),
598                     "buffer.queue"
599                 );
600                 self.queue.push(buf.into());
601             }
602         }
603     }
604 
can_buffer(&self) -> bool605     fn can_buffer(&self) -> bool {
606         match self.strategy {
607             WriteStrategy::Flatten => self.remaining() < self.max_buf_size,
608             WriteStrategy::Queue => {
609                 self.queue.bufs_cnt() < MAX_BUF_LIST_BUFFERS && self.remaining() < self.max_buf_size
610             }
611         }
612     }
613 
headers_mut(&mut self) -> &mut Cursor<Vec<u8>>614     fn headers_mut(&mut self) -> &mut Cursor<Vec<u8>> {
615         debug_assert!(!self.queue.has_remaining());
616         &mut self.headers
617     }
618 }
619 
620 impl<B: Buf> fmt::Debug for WriteBuf<B> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result621     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
622         f.debug_struct("WriteBuf")
623             .field("remaining", &self.remaining())
624             .field("strategy", &self.strategy)
625             .finish()
626     }
627 }
628 
629 impl<B: Buf> Buf for WriteBuf<B> {
630     #[inline]
remaining(&self) -> usize631     fn remaining(&self) -> usize {
632         self.headers.remaining() + self.queue.remaining()
633     }
634 
635     #[inline]
chunk(&self) -> &[u8]636     fn chunk(&self) -> &[u8] {
637         let headers = self.headers.chunk();
638         if !headers.is_empty() {
639             headers
640         } else {
641             self.queue.chunk()
642         }
643     }
644 
645     #[inline]
advance(&mut self, cnt: usize)646     fn advance(&mut self, cnt: usize) {
647         let hrem = self.headers.remaining();
648 
649         match hrem.cmp(&cnt) {
650             cmp::Ordering::Equal => self.headers.reset(),
651             cmp::Ordering::Greater => self.headers.advance(cnt),
652             cmp::Ordering::Less => {
653                 let qcnt = cnt - hrem;
654                 self.headers.reset();
655                 self.queue.advance(qcnt);
656             }
657         }
658     }
659 
660     #[inline]
chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize661     fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
662         let n = self.headers.chunks_vectored(dst);
663         self.queue.chunks_vectored(&mut dst[n..]) + n
664     }
665 }
666 
/// How `WriteBuf::buffer` stores the data it is given.
#[derive(Debug)]
enum WriteStrategy {
    /// Copy every buffer into the single headers `Vec`.
    Flatten,
    /// Keep each buffer separately in a queue. `Buffered::new` selects this
    /// when the transport reports vectored-write support.
    Queue,
}
672 
673 #[cfg(test)]
674 mod tests {
675     use super::*;
676     use std::time::Duration;
677 
678     use tokio_test::io::Builder as Mock;
679 
680     // #[cfg(feature = "nightly")]
681     // use test::Bencher;
682 
683     /*
684     impl<T: Read> MemRead for AsyncIo<T> {
685         fn read_mem(&mut self, len: usize) -> Poll<Bytes, io::Error> {
686             let mut v = vec![0; len];
687             let n = try_nb!(self.read(v.as_mut_slice()));
688             Ok(Async::Ready(BytesMut::from(&v[..n]).freeze()))
689         }
690     }
691     */
692 
    // Ignored placeholder. It is meant to check that flushing an empty
    // `Buffered` never touches the underlying IO. The body is commented out
    // because it relied on a `poll_write_buf`-style vectored-write API (see
    // the TODO below), so the test currently asserts nothing.
    #[tokio::test]
    #[ignore]
    async fn iobuf_write_empty_slice() {
        // TODO(eliza): can i have writev back pls T_T
        // // First, let's just check that the Mock would normally return an
        // // error on an unexpected write, even if the buffer is empty...
        // let mut mock = Mock::new().build();
        // futures_util::future::poll_fn(|cx| {
        //     Pin::new(&mut mock).poll_write_buf(cx, &mut Cursor::new(&[]))
        // })
        // .await
        // .expect_err("should be a broken pipe");

        // // underlying io will return the logic error upon write,
        // // so we are testing that the io_buf does not trigger a write
        // // when there is nothing to flush
        // let mock = Mock::new().build();
        // let mut io_buf = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
        // io_buf.flush().await.expect("should short-circuit flush");
    }
713 
    // `parse` keeps reading while the IO has data. When the mock then
    // blocks with an incomplete head, it returns `Pending` and leaves
    // everything read so far in `read_buf`.
    #[tokio::test]
    async fn parse_reads_until_blocked() {
        use crate::proto::h1::ClientTransaction;

        let _ = pretty_env_logger::try_init();
        let mock = Mock::new()
            // Split over multiple reads will read all of it
            .read(b"HTTP/1.1 200 OK\r\n")
            .read(b"Server: hyper\r\n")
            // missing last line ending
            .wait(Duration::from_secs(1))
            .build();

        let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);

        // We expect a `parse` to be not ready, and so can't await it directly.
        // Rather, this `poll_fn` will wrap the `Poll` result.
        futures_util::future::poll_fn(|cx| {
            // NOTE(review): the timeout fields below are gated on
            // `feature = "runtime"`, while `Buffered::parse` gates them on
            // `all(feature = "server", feature = "runtime")`. Confirm this
            // matches the `ParseContext` definition.
            let parse_ctx = ParseContext {
                cached_headers: &mut None,
                req_method: &mut None,
                h1_parser_config: Default::default(),
                #[cfg(feature = "runtime")]
                h1_header_read_timeout: None,
                #[cfg(feature = "runtime")]
                h1_header_read_timeout_fut: &mut None,
                #[cfg(feature = "runtime")]
                h1_header_read_timeout_running: &mut false,
                preserve_header_case: false,
                #[cfg(feature = "ffi")]
                preserve_header_order: false,
                h09_responses: false,
                #[cfg(feature = "ffi")]
                on_informational: &mut None,
                #[cfg(feature = "ffi")]
                raw_headers: false,
            };
            assert!(buffered
                .parse::<ClientTransaction>(cx, parse_ctx)
                .is_pending());
            Poll::Ready(())
        })
        .await;

        // Both mocked reads were consumed from IO and kept in the buffer.
        assert_eq!(
            buffered.read_buf,
            b"HTTP/1.1 200 OK\r\nServer: hyper\r\n"[..]
        );
    }
763 
    // Growth path of the adaptive read strategy: `next` doubles after each
    // read that fills the whole reservation, however large the read was, and
    // is clamped at `max`.
    #[test]
    fn read_strategy_adaptive_increments() {
        let mut strategy = ReadStrategy::default();
        assert_eq!(strategy.next(), 8192);

        // Grows if record == next
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(16384);
        assert_eq!(strategy.next(), 32768);

        // Enormous records still increment at same rate
        strategy.record(::std::usize::MAX);
        assert_eq!(strategy.next(), 65536);

        let max = strategy.max();
        while strategy.next() < max {
            strategy.record(max);
        }

        assert_eq!(strategy.next(), max, "never goes over max");
        strategy.record(max + 1);
        assert_eq!(strategy.next(), max, "never goes over max");
    }
789 
    // Shrink path of the adaptive read strategy. Shrinking needs two
    // consecutive reads below half of `next`. A read inside the current
    // range in between cancels the pending decrease. `next` never drops
    // below `INIT_BUFFER_SIZE` (8192).
    #[test]
    fn read_strategy_adaptive_decrements() {
        let mut strategy = ReadStrategy::default();
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384);

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "first smaller record doesn't decrement yet"
        );
        strategy.record(8192);
        assert_eq!(strategy.next(), 16384, "record was with range");

        strategy.record(1);
        assert_eq!(
            strategy.next(),
            16384,
            "in-range record should make this the 'first' again"
        );

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "second smaller record decrements");

        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "first doesn't decrement");
        strategy.record(1);
        assert_eq!(strategy.next(), 8192, "doesn't decrement under minimum");
    }
820 
821     #[test]
read_strategy_adaptive_stays_the_same()822     fn read_strategy_adaptive_stays_the_same() {
823         let mut strategy = ReadStrategy::default();
824         strategy.record(8192);
825         assert_eq!(strategy.next(), 16384);
826 
827         strategy.record(8193);
828         assert_eq!(
829             strategy.next(),
830             16384,
831             "first smaller record doesn't decrement yet"
832         );
833 
834         strategy.record(8193);
835         assert_eq!(
836             strategy.next(),
837             16384,
838             "with current step does not decrement"
839         );
840     }
841 
842     #[test]
read_strategy_adaptive_max_fuzz()843     fn read_strategy_adaptive_max_fuzz() {
844         fn fuzz(max: usize) {
845             let mut strategy = ReadStrategy::with_max(max);
846             while strategy.next() < max {
847                 strategy.record(::std::usize::MAX);
848             }
849             let mut next = strategy.next();
850             while next > 8192 {
851                 strategy.record(1);
852                 strategy.record(1);
853                 next = strategy.next();
854                 assert!(
855                     next.is_power_of_two(),
856                     "decrement should be powers of two: {} (max = {})",
857                     next,
858                     max,
859                 );
860             }
861         }
862 
863         let mut max = 8192;
864         while max < std::usize::MAX {
865             fuzz(max);
866             max = (max / 2).saturating_mul(3);
867         }
868         fuzz(::std::usize::MAX);
869     }
870 
871     #[test]
872     #[should_panic]
873     #[cfg(debug_assertions)] // needs to trigger a debug_assert
write_buf_requires_non_empty_bufs()874     fn write_buf_requires_non_empty_bufs() {
875         let mock = Mock::new().build();
876         let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
877 
878         buffered.buffer(Cursor::new(Vec::new()));
879     }
880 
881     /*
882     TODO: needs tokio_test::io to allow configure write_buf calls
883     #[test]
884     fn write_buf_queue() {
885         let _ = pretty_env_logger::try_init();
886 
887         let mock = AsyncIo::new_buf(vec![], 1024);
888         let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
889 
890 
891         buffered.headers_buf().extend(b"hello ");
892         buffered.buffer(Cursor::new(b"world, ".to_vec()));
893         buffered.buffer(Cursor::new(b"it's ".to_vec()));
894         buffered.buffer(Cursor::new(b"hyper!".to_vec()));
895         assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
896         buffered.flush().unwrap();
897 
898         assert_eq!(buffered.io, b"hello world, it's hyper!");
899         assert_eq!(buffered.io.num_writes(), 1);
900         assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
901     }
902     */
903 
904     #[tokio::test]
write_buf_flatten()905     async fn write_buf_flatten() {
906         let _ = pretty_env_logger::try_init();
907 
908         let mock = Mock::new().write(b"hello world, it's hyper!").build();
909 
910         let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
911         buffered.write_buf.set_strategy(WriteStrategy::Flatten);
912 
913         buffered.headers_buf().extend(b"hello ");
914         buffered.buffer(Cursor::new(b"world, ".to_vec()));
915         buffered.buffer(Cursor::new(b"it's ".to_vec()));
916         buffered.buffer(Cursor::new(b"hyper!".to_vec()));
917         assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
918 
919         buffered.flush().await.expect("flush");
920     }
921 
922     #[test]
write_buf_flatten_partially_flushed()923     fn write_buf_flatten_partially_flushed() {
924         let _ = pretty_env_logger::try_init();
925 
926         let b = |s: &str| Cursor::new(s.as_bytes().to_vec());
927 
928         let mut write_buf = WriteBuf::<Cursor<Vec<u8>>>::new(WriteStrategy::Flatten);
929 
930         write_buf.buffer(b("hello "));
931         write_buf.buffer(b("world, "));
932 
933         assert_eq!(write_buf.chunk(), b"hello world, ");
934 
935         // advance most of the way, but not all
936         write_buf.advance(11);
937 
938         assert_eq!(write_buf.chunk(), b", ");
939         assert_eq!(write_buf.headers.pos, 11);
940         assert_eq!(write_buf.headers.bytes.capacity(), INIT_BUFFER_SIZE);
941 
942         // there's still room in the headers buffer, so just push on the end
943         write_buf.buffer(b("it's hyper!"));
944 
945         assert_eq!(write_buf.chunk(), b", it's hyper!");
946         assert_eq!(write_buf.headers.pos, 11);
947 
948         let rem1 = write_buf.remaining();
949         let cap = write_buf.headers.bytes.capacity();
950 
951         // but when this would go over capacity, don't copy the old bytes
952         write_buf.buffer(Cursor::new(vec![b'X'; cap]));
953         assert_eq!(write_buf.remaining(), cap + rem1);
954         assert_eq!(write_buf.headers.pos, 0);
955     }
956 
957     #[tokio::test]
write_buf_queue_disable_auto()958     async fn write_buf_queue_disable_auto() {
959         let _ = pretty_env_logger::try_init();
960 
961         let mock = Mock::new()
962             .write(b"hello ")
963             .write(b"world, ")
964             .write(b"it's ")
965             .write(b"hyper!")
966             .build();
967 
968         let mut buffered = Buffered::<_, Cursor<Vec<u8>>>::new(mock);
969         buffered.write_buf.set_strategy(WriteStrategy::Queue);
970 
971         // we have 4 buffers, and vec IO disabled, but explicitly said
972         // don't try to auto detect (via setting strategy above)
973 
974         buffered.headers_buf().extend(b"hello ");
975         buffered.buffer(Cursor::new(b"world, ".to_vec()));
976         buffered.buffer(Cursor::new(b"it's ".to_vec()));
977         buffered.buffer(Cursor::new(b"hyper!".to_vec()));
978         assert_eq!(buffered.write_buf.queue.bufs_cnt(), 3);
979 
980         buffered.flush().await.expect("flush");
981 
982         assert_eq!(buffered.write_buf.queue.bufs_cnt(), 0);
983     }
984 
985     // #[cfg(feature = "nightly")]
986     // #[bench]
987     // fn bench_write_buf_flatten_buffer_chunk(b: &mut Bencher) {
988     //     let s = "Hello, World!";
989     //     b.bytes = s.len() as u64;
990 
991     //     let mut write_buf = WriteBuf::<bytes::Bytes>::new();
992     //     write_buf.set_strategy(WriteStrategy::Flatten);
993     //     b.iter(|| {
994     //         let chunk = bytes::Bytes::from(s);
995     //         write_buf.buffer(chunk);
996     //         ::test::black_box(&write_buf);
997     //         write_buf.headers.bytes.clear();
998     //     })
999     // }
1000 }
1001