1 use super::{util, StreamDependency, StreamId};
2 use crate::ext::Protocol;
3 use crate::frame::{Error, Frame, Head, Kind};
4 use crate::hpack::{self, BytesStr};
5 
6 use http::header::{self, HeaderName, HeaderValue};
7 use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8 
9 use bytes::{BufMut, Bytes, BytesMut};
10 
11 use std::fmt;
12 use std::io::Cursor;
13 
/// Output buffer used when encoding frames: a `BytesMut` wrapped in a `Limit`.
/// The limit's remaining capacity (`remaining_mut`) decides how many HPACK
/// bytes fit into the frame being written; the rest spills into CONTINUATION
/// frames.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15 
/// Header frame
///
/// This could be a request, a response, or trailers (see `Headers::trailers`).
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any. Only populated by `load`
    /// when a received frame has the PRIORITY flag; locally built frames
    /// leave it `None`.
    stream_dep: Option<StreamDependency>,

    /// The header block (pseudo headers and regular fields)
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}

/// The raw flag byte of a HEADERS frame.
///
/// Received frames keep the byte exactly as read from the frame head;
/// locally built frames start from `HeadersFlag::default()` (`END_HEADERS`).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36 
/// PUSH_PROMISE frame: reserves `promised_id` for a server push and carries
/// the headers of the promised request.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block (pseudo headers and regular fields)
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}

/// The raw flag byte of a PUSH_PROMISE frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54 
/// An outgoing CONTINUATION frame still to be written: the HPACK bytes that
/// did not fit into the previous HEADERS, PUSH_PROMISE or CONTINUATION frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// Remaining, already HPACK-encoded header bytes
    header_block: EncodingHeaderBlock,
}
62 
// TODO: These fields shouldn't be `pub`
/// HTTP/2 pseudo header fields.
///
/// Requests use `method`, `scheme`, `authority`, `path` and `protocol`;
/// responses use `status`. Absent pseudo headers are `None`.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<BytesStr>,
    /// `:authority`
    pub authority: Option<BytesStr>,
    /// `:path`
    pub path: Option<BytesStr>,
    /// `:protocol`
    pub protocol: Option<Protocol>,

    // Response
    /// `:status`
    pub status: Option<StatusCode>,
}
76 
/// Owning iterator over a header block for HPACK encoding: yields the pseudo
/// headers first, then the regular fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers; set to `None` once all of them have been yielded
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
85 
/// A decoded (or not yet encoded) header block.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of the regular fields in `fields`, as counted toward
    /// the header list size. Pseudo headers are not included here; they are
    /// added separately by `calculate_header_list_size`.
    field_size: usize,

    /// Set to true if decoding went over the max header list size. Headers
    /// past the limit are not stored in the block.
    is_over_size: bool,

    /// Pseudo headers, these are broken out because they are encoded ahead of
    /// every regular field.
    pseudo: Pseudo,
}

/// A header block that has already been HPACK-encoded.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes not yet written into a frame
    hpack: Bytes,
}
106 
// Flag bits of HEADERS frames (RFC 7540 §6.2). PUSH_PROMISE frames use the
// END_HEADERS and PADDED bits with the same values (RFC 7540 §6.6).
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
// Union of every flag bit above; `HeadersFlag::load` masks raw bits with it.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112 
113 // ===== impl Headers =====
114 
115 impl Headers {
116     /// Create a new HEADERS frame
new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self117     pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
118         Headers {
119             stream_id,
120             stream_dep: None,
121             header_block: HeaderBlock {
122                 field_size: calculate_headermap_size(&fields),
123                 fields,
124                 is_over_size: false,
125                 pseudo,
126             },
127             flags: HeadersFlag::default(),
128         }
129     }
130 
trailers(stream_id: StreamId, fields: HeaderMap) -> Self131     pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
132         let mut flags = HeadersFlag::default();
133         flags.set_end_stream();
134 
135         Headers {
136             stream_id,
137             stream_dep: None,
138             header_block: HeaderBlock {
139                 field_size: calculate_headermap_size(&fields),
140                 fields,
141                 is_over_size: false,
142                 pseudo: Pseudo::default(),
143             },
144             flags,
145         }
146     }
147 
148     /// Loads the header frame but doesn't actually do HPACK decoding.
149     ///
150     /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>151     pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
152         let flags = HeadersFlag(head.flag());
153         let mut pad = 0;
154 
155         tracing::trace!("loading headers; flags={:?}", flags);
156 
157         if head.stream_id().is_zero() {
158             return Err(Error::InvalidStreamId);
159         }
160 
161         // Read the padding length
162         if flags.is_padded() {
163             if src.is_empty() {
164                 return Err(Error::MalformedMessage);
165             }
166             pad = src[0] as usize;
167 
168             // Drop the padding
169             let _ = src.split_to(1);
170         }
171 
172         // Read the stream dependency
173         let stream_dep = if flags.is_priority() {
174             if src.len() < 5 {
175                 return Err(Error::MalformedMessage);
176             }
177             let stream_dep = StreamDependency::load(&src[..5])?;
178 
179             if stream_dep.dependency_id() == head.stream_id() {
180                 return Err(Error::InvalidDependencyId);
181             }
182 
183             // Drop the next 5 bytes
184             let _ = src.split_to(5);
185 
186             Some(stream_dep)
187         } else {
188             None
189         };
190 
191         if pad > 0 {
192             if pad > src.len() {
193                 return Err(Error::TooMuchPadding);
194             }
195 
196             let len = src.len() - pad;
197             src.truncate(len);
198         }
199 
200         let headers = Headers {
201             stream_id: head.stream_id(),
202             stream_dep,
203             header_block: HeaderBlock {
204                 fields: HeaderMap::new(),
205                 field_size: 0,
206                 is_over_size: false,
207                 pseudo: Pseudo::default(),
208             },
209             flags,
210         };
211 
212         Ok((headers, src))
213     }
214 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>215     pub fn load_hpack(
216         &mut self,
217         src: &mut BytesMut,
218         max_header_list_size: usize,
219         decoder: &mut hpack::Decoder,
220     ) -> Result<(), Error> {
221         self.header_block.load(src, max_header_list_size, decoder)
222     }
223 
stream_id(&self) -> StreamId224     pub fn stream_id(&self) -> StreamId {
225         self.stream_id
226     }
227 
is_end_headers(&self) -> bool228     pub fn is_end_headers(&self) -> bool {
229         self.flags.is_end_headers()
230     }
231 
set_end_headers(&mut self)232     pub fn set_end_headers(&mut self) {
233         self.flags.set_end_headers();
234     }
235 
is_end_stream(&self) -> bool236     pub fn is_end_stream(&self) -> bool {
237         self.flags.is_end_stream()
238     }
239 
set_end_stream(&mut self)240     pub fn set_end_stream(&mut self) {
241         self.flags.set_end_stream()
242     }
243 
is_over_size(&self) -> bool244     pub fn is_over_size(&self) -> bool {
245         self.header_block.is_over_size
246     }
247 
into_parts(self) -> (Pseudo, HeaderMap)248     pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249         (self.header_block.pseudo, self.header_block.fields)
250     }
251 
252     #[cfg(feature = "unstable")]
pseudo_mut(&mut self) -> &mut Pseudo253     pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254         &mut self.header_block.pseudo
255     }
256 
257     /// Whether it has status 1xx
is_informational(&self) -> bool258     pub(crate) fn is_informational(&self) -> bool {
259         self.header_block.pseudo.is_informational()
260     }
261 
fields(&self) -> &HeaderMap262     pub fn fields(&self) -> &HeaderMap {
263         &self.header_block.fields
264     }
265 
into_fields(self) -> HeaderMap266     pub fn into_fields(self) -> HeaderMap {
267         self.header_block.fields
268     }
269 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>270     pub fn encode(
271         self,
272         encoder: &mut hpack::Encoder,
273         dst: &mut EncodeBuf<'_>,
274     ) -> Option<Continuation> {
275         // At this point, the `is_end_headers` flag should always be set
276         debug_assert!(self.flags.is_end_headers());
277 
278         // Get the HEADERS frame head
279         let head = self.head();
280 
281         self.header_block
282             .into_encoding(encoder)
283             .encode(&head, dst, |_| {})
284     }
285 
head(&self) -> Head286     fn head(&self) -> Head {
287         Head::new(Kind::Headers, self.flags.into(), self.stream_id)
288     }
289 }
290 
291 impl<T> From<Headers> for Frame<T> {
from(src: Headers) -> Self292     fn from(src: Headers) -> Self {
293         Frame::Headers(src)
294     }
295 }
296 
297 impl fmt::Debug for Headers {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result298     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299         let mut builder = f.debug_struct("Headers");
300         builder
301             .field("stream_id", &self.stream_id)
302             .field("flags", &self.flags);
303 
304         if let Some(ref protocol) = self.header_block.pseudo.protocol {
305             builder.field("protocol", protocol);
306         }
307 
308         if let Some(ref dep) = self.stream_dep {
309             builder.field("stream_dep", dep);
310         }
311 
312         // `fields` and `pseudo` purposefully not included
313         builder.finish()
314     }
315 }
316 
317 // ===== util =====
318 
/// Error returned by `parse_u64` for input that is not a valid unsigned
/// decimal number fitting in a `u64`.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses an unsigned decimal integer made solely of ASCII digits, as used by
/// `content-length` (`1*DIGIT`).
///
/// # Errors
///
/// Returns `ParseU64Error` when `src` is empty, contains any byte other than
/// `0`-`9` (signs and whitespace included), or denotes a value larger than
/// `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // At least one digit is required: an empty value is not `0`.
    if src.is_empty() {
        return Err(ParseU64Error);
    }

    // Checked arithmetic rejects overflow exactly, instead of guessing from
    // the number of digits.
    src.iter().try_fold(0u64, |acc, &d| {
        if !d.is_ascii_digit() {
            return Err(ParseU64Error);
        }

        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(d - b'0')))
            .ok_or(ParseU64Error)
    })
}
341 
342 // ===== impl PushPromise =====
343 
/// Reasons `PushPromise::validate_request` rejects a request that is about to
/// be promised.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// The request has a `content-length` that does not parse to `0`. Holds
    /// the parse result: `Ok(n)` for a non-zero length, `Err` if unparsable.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is not safe and cacheable (only GET and HEAD are
    /// accepted).
    NotSafeAndCacheable,
}
349 
350 impl PushPromise {
new( stream_id: StreamId, promised_id: StreamId, pseudo: Pseudo, fields: HeaderMap, ) -> Self351     pub fn new(
352         stream_id: StreamId,
353         promised_id: StreamId,
354         pseudo: Pseudo,
355         fields: HeaderMap,
356     ) -> Self {
357         PushPromise {
358             flags: PushPromiseFlag::default(),
359             header_block: HeaderBlock {
360                 field_size: calculate_headermap_size(&fields),
361                 fields,
362                 is_over_size: false,
363                 pseudo,
364             },
365             promised_id,
366             stream_id,
367         }
368     }
369 
validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError>370     pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
371         use PushPromiseHeaderError::*;
372         // The spec has some requirements for promised request headers
373         // [https://httpwg.org/specs/rfc7540.html#PushRequests]
374 
375         // A promised request "that indicates the presence of a request body
376         // MUST reset the promised stream with a stream error"
377         if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
378             let parsed_length = parse_u64(content_length.as_bytes());
379             if parsed_length != Ok(0) {
380                 return Err(InvalidContentLength(parsed_length));
381             }
382         }
383         // "The server MUST include a method in the :method pseudo-header field
384         // that is safe and cacheable"
385         if !Self::safe_and_cacheable(req.method()) {
386             return Err(NotSafeAndCacheable);
387         }
388 
389         Ok(())
390     }
391 
safe_and_cacheable(method: &Method) -> bool392     fn safe_and_cacheable(method: &Method) -> bool {
393         // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
394         // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
395         method == Method::GET || method == Method::HEAD
396     }
397 
fields(&self) -> &HeaderMap398     pub fn fields(&self) -> &HeaderMap {
399         &self.header_block.fields
400     }
401 
402     #[cfg(feature = "unstable")]
into_fields(self) -> HeaderMap403     pub fn into_fields(self) -> HeaderMap {
404         self.header_block.fields
405     }
406 
407     /// Loads the push promise frame but doesn't actually do HPACK decoding.
408     ///
409     /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>410     pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
411         let flags = PushPromiseFlag(head.flag());
412         let mut pad = 0;
413 
414         if head.stream_id().is_zero() {
415             return Err(Error::InvalidStreamId);
416         }
417 
418         // Read the padding length
419         if flags.is_padded() {
420             if src.is_empty() {
421                 return Err(Error::MalformedMessage);
422             }
423 
424             // TODO: Ensure payload is sized correctly
425             pad = src[0] as usize;
426 
427             // Drop the padding
428             let _ = src.split_to(1);
429         }
430 
431         if src.len() < 5 {
432             return Err(Error::MalformedMessage);
433         }
434 
435         let (promised_id, _) = StreamId::parse(&src[..4]);
436         // Drop promised_id bytes
437         let _ = src.split_to(4);
438 
439         if pad > 0 {
440             if pad > src.len() {
441                 return Err(Error::TooMuchPadding);
442             }
443 
444             let len = src.len() - pad;
445             src.truncate(len);
446         }
447 
448         let frame = PushPromise {
449             flags,
450             header_block: HeaderBlock {
451                 fields: HeaderMap::new(),
452                 field_size: 0,
453                 is_over_size: false,
454                 pseudo: Pseudo::default(),
455             },
456             promised_id,
457             stream_id: head.stream_id(),
458         };
459         Ok((frame, src))
460     }
461 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>462     pub fn load_hpack(
463         &mut self,
464         src: &mut BytesMut,
465         max_header_list_size: usize,
466         decoder: &mut hpack::Decoder,
467     ) -> Result<(), Error> {
468         self.header_block.load(src, max_header_list_size, decoder)
469     }
470 
stream_id(&self) -> StreamId471     pub fn stream_id(&self) -> StreamId {
472         self.stream_id
473     }
474 
promised_id(&self) -> StreamId475     pub fn promised_id(&self) -> StreamId {
476         self.promised_id
477     }
478 
is_end_headers(&self) -> bool479     pub fn is_end_headers(&self) -> bool {
480         self.flags.is_end_headers()
481     }
482 
set_end_headers(&mut self)483     pub fn set_end_headers(&mut self) {
484         self.flags.set_end_headers();
485     }
486 
is_over_size(&self) -> bool487     pub fn is_over_size(&self) -> bool {
488         self.header_block.is_over_size
489     }
490 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>491     pub fn encode(
492         self,
493         encoder: &mut hpack::Encoder,
494         dst: &mut EncodeBuf<'_>,
495     ) -> Option<Continuation> {
496         // At this point, the `is_end_headers` flag should always be set
497         debug_assert!(self.flags.is_end_headers());
498 
499         let head = self.head();
500         let promised_id = self.promised_id;
501 
502         self.header_block
503             .into_encoding(encoder)
504             .encode(&head, dst, |dst| {
505                 dst.put_u32(promised_id.into());
506             })
507     }
508 
head(&self) -> Head509     fn head(&self) -> Head {
510         Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
511     }
512 
513     /// Consume `self`, returning the parts of the frame
into_parts(self) -> (Pseudo, HeaderMap)514     pub fn into_parts(self) -> (Pseudo, HeaderMap) {
515         (self.header_block.pseudo, self.header_block.fields)
516     }
517 }
518 
519 impl<T> From<PushPromise> for Frame<T> {
from(src: PushPromise) -> Self520     fn from(src: PushPromise) -> Self {
521         Frame::PushPromise(src)
522     }
523 }
524 
525 impl fmt::Debug for PushPromise {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result526     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
527         f.debug_struct("PushPromise")
528             .field("stream_id", &self.stream_id)
529             .field("promised_id", &self.promised_id)
530             .field("flags", &self.flags)
531             // `fields` and `pseudo` purposefully not included
532             .finish()
533     }
534 }
535 
536 // ===== impl Continuation =====
537 
538 impl Continuation {
head(&self) -> Head539     fn head(&self) -> Head {
540         Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
541     }
542 
encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation>543     pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
544         // Get the CONTINUATION frame head
545         let head = self.head();
546 
547         self.header_block.encode(&head, dst, |_| {})
548     }
549 }
550 
551 // ===== impl Pseudo =====
552 
553 impl Pseudo {
request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self554     pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
555         let parts = uri::Parts::from(uri);
556 
557         let mut path = parts
558             .path_and_query
559             .map(|v| BytesStr::from(v.as_str()))
560             .unwrap_or(BytesStr::from_static(""));
561 
562         match method {
563             Method::OPTIONS | Method::CONNECT => {}
564             _ if path.is_empty() => {
565                 path = BytesStr::from_static("/");
566             }
567             _ => {}
568         }
569 
570         let mut pseudo = Pseudo {
571             method: Some(method),
572             scheme: None,
573             authority: None,
574             path: Some(path).filter(|p| !p.is_empty()),
575             protocol,
576             status: None,
577         };
578 
579         // If the URI includes a scheme component, add it to the pseudo headers
580         //
581         // TODO: Scheme must be set...
582         if let Some(scheme) = parts.scheme {
583             pseudo.set_scheme(scheme);
584         }
585 
586         // If the URI includes an authority component, add it to the pseudo
587         // headers
588         if let Some(authority) = parts.authority {
589             pseudo.set_authority(BytesStr::from(authority.as_str()));
590         }
591 
592         pseudo
593     }
594 
response(status: StatusCode) -> Self595     pub fn response(status: StatusCode) -> Self {
596         Pseudo {
597             method: None,
598             scheme: None,
599             authority: None,
600             path: None,
601             protocol: None,
602             status: Some(status),
603         }
604     }
605 
606     #[cfg(feature = "unstable")]
set_status(&mut self, value: StatusCode)607     pub fn set_status(&mut self, value: StatusCode) {
608         self.status = Some(value);
609     }
610 
set_scheme(&mut self, scheme: uri::Scheme)611     pub fn set_scheme(&mut self, scheme: uri::Scheme) {
612         let bytes_str = match scheme.as_str() {
613             "http" => BytesStr::from_static("http"),
614             "https" => BytesStr::from_static("https"),
615             s => BytesStr::from(s),
616         };
617         self.scheme = Some(bytes_str);
618     }
619 
620     #[cfg(feature = "unstable")]
set_protocol(&mut self, protocol: Protocol)621     pub fn set_protocol(&mut self, protocol: Protocol) {
622         self.protocol = Some(protocol);
623     }
624 
set_authority(&mut self, authority: BytesStr)625     pub fn set_authority(&mut self, authority: BytesStr) {
626         self.authority = Some(authority);
627     }
628 
629     /// Whether it has status 1xx
is_informational(&self) -> bool630     pub(crate) fn is_informational(&self) -> bool {
631         self.status
632             .map_or(false, |status| status.is_informational())
633     }
634 }
635 
636 // ===== impl EncodingHeaderBlock =====
637 
638 impl EncodingHeaderBlock {
encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> where F: FnOnce(&mut EncodeBuf<'_>),639     fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
640     where
641         F: FnOnce(&mut EncodeBuf<'_>),
642     {
643         let head_pos = dst.get_ref().len();
644 
645         // At this point, we don't know how big the h2 frame will be.
646         // So, we write the head with length 0, then write the body, and
647         // finally write the length once we know the size.
648         head.encode(0, dst);
649 
650         let payload_pos = dst.get_ref().len();
651 
652         f(dst);
653 
654         // Now, encode the header payload
655         let continuation = if self.hpack.len() > dst.remaining_mut() {
656             dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));
657 
658             Some(Continuation {
659                 stream_id: head.stream_id(),
660                 header_block: self,
661             })
662         } else {
663             dst.put_slice(&self.hpack);
664 
665             None
666         };
667 
668         // Compute the header block length
669         let payload_len = (dst.get_ref().len() - payload_pos) as u64;
670 
671         // Write the frame length
672         let payload_len_be = payload_len.to_be_bytes();
673         assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
674         (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
675 
676         if continuation.is_some() {
677             // There will be continuation frames, so the `is_end_headers` flag
678             // must be unset
679             debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
680 
681             dst.get_mut()[head_pos + 4] -= END_HEADERS;
682         }
683 
684         continuation
685     }
686 }
687 
688 // ===== impl Iter =====
689 
690 impl Iterator for Iter {
691     type Item = hpack::Header<Option<HeaderName>>;
692 
next(&mut self) -> Option<Self::Item>693     fn next(&mut self) -> Option<Self::Item> {
694         use crate::hpack::Header::*;
695 
696         if let Some(ref mut pseudo) = self.pseudo {
697             if let Some(method) = pseudo.method.take() {
698                 return Some(Method(method));
699             }
700 
701             if let Some(scheme) = pseudo.scheme.take() {
702                 return Some(Scheme(scheme));
703             }
704 
705             if let Some(authority) = pseudo.authority.take() {
706                 return Some(Authority(authority));
707             }
708 
709             if let Some(path) = pseudo.path.take() {
710                 return Some(Path(path));
711             }
712 
713             if let Some(protocol) = pseudo.protocol.take() {
714                 return Some(Protocol(protocol));
715             }
716 
717             if let Some(status) = pseudo.status.take() {
718                 return Some(Status(status));
719             }
720         }
721 
722         self.pseudo = None;
723 
724         self.fields
725             .next()
726             .map(|(name, value)| Field { name, value })
727     }
728 }
729 
730 // ===== impl HeadersFlag =====
731 
732 impl HeadersFlag {
empty() -> HeadersFlag733     pub fn empty() -> HeadersFlag {
734         HeadersFlag(0)
735     }
736 
load(bits: u8) -> HeadersFlag737     pub fn load(bits: u8) -> HeadersFlag {
738         HeadersFlag(bits & ALL)
739     }
740 
is_end_stream(&self) -> bool741     pub fn is_end_stream(&self) -> bool {
742         self.0 & END_STREAM == END_STREAM
743     }
744 
set_end_stream(&mut self)745     pub fn set_end_stream(&mut self) {
746         self.0 |= END_STREAM;
747     }
748 
is_end_headers(&self) -> bool749     pub fn is_end_headers(&self) -> bool {
750         self.0 & END_HEADERS == END_HEADERS
751     }
752 
set_end_headers(&mut self)753     pub fn set_end_headers(&mut self) {
754         self.0 |= END_HEADERS;
755     }
756 
is_padded(&self) -> bool757     pub fn is_padded(&self) -> bool {
758         self.0 & PADDED == PADDED
759     }
760 
is_priority(&self) -> bool761     pub fn is_priority(&self) -> bool {
762         self.0 & PRIORITY == PRIORITY
763     }
764 }
765 
766 impl Default for HeadersFlag {
767     /// Returns a `HeadersFlag` value with `END_HEADERS` set.
default() -> Self768     fn default() -> Self {
769         HeadersFlag(END_HEADERS)
770     }
771 }
772 
773 impl From<HeadersFlag> for u8 {
from(src: HeadersFlag) -> u8774     fn from(src: HeadersFlag) -> u8 {
775         src.0
776     }
777 }
778 
779 impl fmt::Debug for HeadersFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result780     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
781         util::debug_flags(fmt, self.0)
782             .flag_if(self.is_end_headers(), "END_HEADERS")
783             .flag_if(self.is_end_stream(), "END_STREAM")
784             .flag_if(self.is_padded(), "PADDED")
785             .flag_if(self.is_priority(), "PRIORITY")
786             .finish()
787     }
788 }
789 
790 // ===== impl PushPromiseFlag =====
791 
792 impl PushPromiseFlag {
empty() -> PushPromiseFlag793     pub fn empty() -> PushPromiseFlag {
794         PushPromiseFlag(0)
795     }
796 
load(bits: u8) -> PushPromiseFlag797     pub fn load(bits: u8) -> PushPromiseFlag {
798         PushPromiseFlag(bits & ALL)
799     }
800 
is_end_headers(&self) -> bool801     pub fn is_end_headers(&self) -> bool {
802         self.0 & END_HEADERS == END_HEADERS
803     }
804 
set_end_headers(&mut self)805     pub fn set_end_headers(&mut self) {
806         self.0 |= END_HEADERS;
807     }
808 
is_padded(&self) -> bool809     pub fn is_padded(&self) -> bool {
810         self.0 & PADDED == PADDED
811     }
812 }
813 
814 impl Default for PushPromiseFlag {
815     /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
default() -> Self816     fn default() -> Self {
817         PushPromiseFlag(END_HEADERS)
818     }
819 }
820 
821 impl From<PushPromiseFlag> for u8 {
from(src: PushPromiseFlag) -> u8822     fn from(src: PushPromiseFlag) -> u8 {
823         src.0
824     }
825 }
826 
827 impl fmt::Debug for PushPromiseFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result828     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
829         util::debug_flags(fmt, self.0)
830             .flag_if(self.is_end_headers(), "END_HEADERS")
831             .flag_if(self.is_padded(), "PADDED")
832             .finish()
833     }
834 }
835 
836 // ===== HeaderBlock =====
837 
838 impl HeaderBlock {
load( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>839     fn load(
840         &mut self,
841         src: &mut BytesMut,
842         max_header_list_size: usize,
843         decoder: &mut hpack::Decoder,
844     ) -> Result<(), Error> {
845         let mut reg = !self.fields.is_empty();
846         let mut malformed = false;
847         let mut headers_size = self.calculate_header_list_size();
848 
849         macro_rules! set_pseudo {
850             ($field:ident, $val:expr) => {{
851                 if reg {
852                     tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
853                     malformed = true;
854                 } else if self.pseudo.$field.is_some() {
855                     tracing::trace!("load_hpack; header malformed -- repeated pseudo");
856                     malformed = true;
857                 } else {
858                     let __val = $val;
859                     headers_size +=
860                         decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
861                     if headers_size < max_header_list_size {
862                         self.pseudo.$field = Some(__val);
863                     } else if !self.is_over_size {
864                         tracing::trace!("load_hpack; header list size over max");
865                         self.is_over_size = true;
866                     }
867                 }
868             }};
869         }
870 
871         let mut cursor = Cursor::new(src);
872 
873         // If the header frame is malformed, we still have to continue decoding
874         // the headers. A malformed header frame is a stream level error, but
875         // the hpack state is connection level. In order to maintain correct
876         // state for other streams, the hpack decoding process must complete.
877         let res = decoder.decode(&mut cursor, |header| {
878             use crate::hpack::Header::*;
879 
880             match header {
881                 Field { name, value } => {
882                     // Connection level header fields are not supported and must
883                     // result in a protocol error.
884 
885                     if name == header::CONNECTION
886                         || name == header::TRANSFER_ENCODING
887                         || name == header::UPGRADE
888                         || name == "keep-alive"
889                         || name == "proxy-connection"
890                     {
891                         tracing::trace!("load_hpack; connection level header");
892                         malformed = true;
893                     } else if name == header::TE && value != "trailers" {
894                         tracing::trace!(
895                             "load_hpack; TE header not set to trailers; val={:?}",
896                             value
897                         );
898                         malformed = true;
899                     } else {
900                         reg = true;
901 
902                         headers_size += decoded_header_size(name.as_str().len(), value.len());
903                         if headers_size < max_header_list_size {
904                             self.field_size +=
905                                 decoded_header_size(name.as_str().len(), value.len());
906                             self.fields.append(name, value);
907                         } else if !self.is_over_size {
908                             tracing::trace!("load_hpack; header list size over max");
909                             self.is_over_size = true;
910                         }
911                     }
912                 }
913                 Authority(v) => set_pseudo!(authority, v),
914                 Method(v) => set_pseudo!(method, v),
915                 Scheme(v) => set_pseudo!(scheme, v),
916                 Path(v) => set_pseudo!(path, v),
917                 Protocol(v) => set_pseudo!(protocol, v),
918                 Status(v) => set_pseudo!(status, v),
919             }
920         });
921 
922         if let Err(e) = res {
923             tracing::trace!("hpack decoding error; err={:?}", e);
924             return Err(e.into());
925         }
926 
927         if malformed {
928             tracing::trace!("malformed message");
929             return Err(Error::MalformedMessage);
930         }
931 
932         Ok(())
933     }
934 
into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock935     fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
936         let mut hpack = BytesMut::new();
937         let headers = Iter {
938             pseudo: Some(self.pseudo),
939             fields: self.fields.into_iter(),
940         };
941 
942         encoder.encode(headers, &mut hpack);
943 
944         EncodingHeaderBlock {
945             hpack: hpack.freeze(),
946         }
947     }
948 
949     /// Calculates the size of the currently decoded header list.
950     ///
951     /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
952     ///
953     /// > The value is based on the uncompressed size of header fields,
954     /// > including the length of the name and value in octets plus an
955     /// > overhead of 32 octets for each header field.
calculate_header_list_size(&self) -> usize956     fn calculate_header_list_size(&self) -> usize {
957         macro_rules! pseudo_size {
958             ($name:ident) => {{
959                 self.pseudo
960                     .$name
961                     .as_ref()
962                     .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
963                     .unwrap_or(0)
964             }};
965         }
966 
967         pseudo_size!(method)
968             + pseudo_size!(scheme)
969             + pseudo_size!(status)
970             + pseudo_size!(authority)
971             + pseudo_size!(path)
972             + self.field_size
973     }
974 }
975 
calculate_headermap_size(map: &HeaderMap) -> usize976 fn calculate_headermap_size(map: &HeaderMap) -> usize {
977     map.iter()
978         .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
979         .sum::<usize>()
980 }
981 
/// Size of a single decoded header field for `SETTINGS_MAX_HEADER_LIST_SIZE`
/// accounting. It is the name length plus the value length (both in octets)
/// plus a fixed per-field overhead of 32 octets (RFC 7540 §6.5.2).
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    let payload = name + value;
    payload + PER_FIELD_OVERHEAD
}
985 
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three values for the same header name under a buffer limit that
    /// splits the header block mid-value. It then checks that the returned
    /// continuation emits the rest of the block correctly.
    ///
    /// NOTE(review): the test name suggests it targets the 2nd/3rd values. Those
    /// presumably reach the encoder without a name of their own (repeated
    /// `HeaderMap` entries) — confirm against `Iter`.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Allow only 8 payload bytes after the frame header. The block does not
        // fit, so `encode` must hand back a continuation (hence the unwrap).
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // First frame: 9-byte frame header followed by 8 bytes of payload.
        assert_eq!(17, dst.len());
        // length = 8, type = 1 (HEADERS), flags = 0 (END_HEADERS not set), stream 0.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // 0x40: literal with incremental indexing and a new name (RFC 7541 §6.2.1).
        // 0x80 | 4: Huffman-coded name, 4 bytes long.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Huffman-coded value, 4 bytes long. Only its first byte fits in this frame.
        assert_eq!(0x80 | 4, dst[15]);

        // Keep the first byte of the split value so it can be rejoined with the
        // remainder emitted by the continuation frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // With room for 16 payload bytes the rest of the block fits, so no
        // further continuation is returned.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The remaining 3 bytes of the "world" value open the continuation payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        assert_eq!(24, dst.len());
        // length = 15, type = 9 (CONTINUATION), flags = 4 (END_HEADERS), stream 0.
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The next fields are not indexed. Per RFC 7541 §6.2.2 / §5.1, `15, 47`
        // appears to be a literal without indexing whose name is index
        // 15 + 47 = 62, the first dynamic-table entry ("hello", inserted above).
        // NOTE(review): interpretation taken from the HPACK spec; confirm.
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a fresh buffer, panicking if decoding fails.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}
1052