1 use super::encode::BUFFER_SIZE;
2 use crate::{metadata::MetadataValue, Status};
3 use bytes::{Buf, BytesMut};
4 #[cfg(feature = "gzip")]
5 use flate2::read::{GzDecoder, GzEncoder};
6 use std::fmt;
7 #[cfg(feature = "zstd")]
8 use zstd::stream::read::{Decoder, Encoder};
9 
/// Name of the header that states which encoding a message was compressed with.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that lists the encodings the sender is able to accept.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
12 
/// The set of compression encodings switched on for a server or channel.
///
/// A flag only exists when the matching cargo feature is compiled in. The
/// [`Default`] value has every available encoding turned off.
#[derive(Clone, Copy, Debug, Default)]
pub struct EnabledCompressionEncodings {
    /// `true` once gzip has been enabled.
    #[cfg(feature = "gzip")]
    pub(crate) gzip: bool,
    /// `true` once zstd has been enabled.
    #[cfg(feature = "zstd")]
    pub(crate) zstd: bool,
}
21 
22 impl EnabledCompressionEncodings {
23     /// Check if a [`CompressionEncoding`] is enabled.
is_enabled(&self, encoding: CompressionEncoding) -> bool24     pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
25         match encoding {
26             #[cfg(feature = "gzip")]
27             CompressionEncoding::Gzip => self.gzip,
28             #[cfg(feature = "zstd")]
29             CompressionEncoding::Zstd => self.zstd,
30         }
31     }
32 
33     /// Enable a [`CompressionEncoding`].
enable(&mut self, encoding: CompressionEncoding)34     pub fn enable(&mut self, encoding: CompressionEncoding) {
35         match encoding {
36             #[cfg(feature = "gzip")]
37             CompressionEncoding::Gzip => self.gzip = true,
38             #[cfg(feature = "zstd")]
39             CompressionEncoding::Zstd => self.zstd = true,
40         }
41     }
42 
into_accept_encoding_header_value(self) -> Option<http::HeaderValue>43     pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
44         match (self.is_gzip_enabled(), self.is_zstd_enabled()) {
45             (true, false) => Some(http::HeaderValue::from_static("gzip,identity")),
46             (false, true) => Some(http::HeaderValue::from_static("zstd,identity")),
47             (true, true) => Some(http::HeaderValue::from_static("gzip,zstd,identity")),
48             (false, false) => None,
49         }
50     }
51 
52     #[cfg(feature = "gzip")]
is_gzip_enabled(&self) -> bool53     const fn is_gzip_enabled(&self) -> bool {
54         self.gzip
55     }
56 
57     #[cfg(not(feature = "gzip"))]
is_gzip_enabled(&self) -> bool58     const fn is_gzip_enabled(&self) -> bool {
59         false
60     }
61 
62     #[cfg(feature = "zstd")]
is_zstd_enabled(&self) -> bool63     const fn is_zstd_enabled(&self) -> bool {
64         self.zstd
65     }
66 
67     #[cfg(not(feature = "zstd"))]
is_zstd_enabled(&self) -> bool68     const fn is_zstd_enabled(&self) -> bool {
69         false
70     }
71 }
72 
/// The compression encodings Tonic supports.
///
/// Each variant is only present when its cargo feature is enabled.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// The `gzip` encoding.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,
    /// The `zstd` encoding.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
86 
87 impl CompressionEncoding {
88     /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
from_accept_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option<Self>89     pub(crate) fn from_accept_encoding_header(
90         map: &http::HeaderMap,
91         enabled_encodings: EnabledCompressionEncodings,
92     ) -> Option<Self> {
93         if !enabled_encodings.is_gzip_enabled() && !enabled_encodings.is_zstd_enabled() {
94             return None;
95         }
96 
97         let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
98         let header_value_str = header_value.to_str().ok()?;
99 
100         split_by_comma(header_value_str).find_map(|value| match value {
101             #[cfg(feature = "gzip")]
102             "gzip" => Some(CompressionEncoding::Gzip),
103             #[cfg(feature = "zstd")]
104             "zstd" => Some(CompressionEncoding::Zstd),
105             _ => None,
106         })
107     }
108 
109     /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
from_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Result<Option<Self>, Status>110     pub(crate) fn from_encoding_header(
111         map: &http::HeaderMap,
112         enabled_encodings: EnabledCompressionEncodings,
113     ) -> Result<Option<Self>, Status> {
114         let header_value = if let Some(value) = map.get(ENCODING_HEADER) {
115             value
116         } else {
117             return Ok(None);
118         };
119 
120         let header_value_str = if let Ok(value) = header_value.to_str() {
121             value
122         } else {
123             return Ok(None);
124         };
125 
126         match header_value_str {
127             #[cfg(feature = "gzip")]
128             "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
129                 Ok(Some(CompressionEncoding::Gzip))
130             }
131             #[cfg(feature = "zstd")]
132             "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
133                 Ok(Some(CompressionEncoding::Zstd))
134             }
135             "identity" => Ok(None),
136             other => {
137                 let mut status = Status::unimplemented(format!(
138                     "Content is compressed with `{}` which isn't supported",
139                     other
140                 ));
141 
142                 let header_value = enabled_encodings
143                     .into_accept_encoding_header_value()
144                     .map(MetadataValue::unchecked_from_header_value)
145                     .unwrap_or_else(|| MetadataValue::from_static("identity"));
146                 status
147                     .metadata_mut()
148                     .insert(ACCEPT_ENCODING_HEADER, header_value);
149 
150                 Err(status)
151             }
152         }
153     }
154 
155     #[allow(missing_docs)]
156     #[cfg(any(feature = "gzip", feature = "zstd"))]
as_str(&self) -> &'static str157     pub(crate) fn as_str(&self) -> &'static str {
158         match self {
159             #[cfg(feature = "gzip")]
160             CompressionEncoding::Gzip => "gzip",
161             #[cfg(feature = "zstd")]
162             CompressionEncoding::Zstd => "zstd",
163         }
164     }
165 
166     #[cfg(any(feature = "gzip", feature = "zstd"))]
into_header_value(self) -> http::HeaderValue167     pub(crate) fn into_header_value(self) -> http::HeaderValue {
168         http::HeaderValue::from_static(self.as_str())
169     }
170 
encodings() -> &'static [Self]171     pub(crate) fn encodings() -> &'static [Self] {
172         &[
173             #[cfg(feature = "gzip")]
174             CompressionEncoding::Gzip,
175             #[cfg(feature = "zstd")]
176             CompressionEncoding::Zstd,
177         ]
178     }
179 }
180 
181 impl fmt::Display for CompressionEncoding {
182     #[allow(unused_variables)]
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result183     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184         match *self {
185             #[cfg(feature = "gzip")]
186             CompressionEncoding::Gzip => write!(f, "gzip"),
187             #[cfg(feature = "zstd")]
188             CompressionEncoding::Zstd => write!(f, "zstd"),
189         }
190     }
191 }
192 
/// Split `s` on commas and yield each piece with its surrounding whitespace
/// removed. Empty pieces are kept.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    // Trimming every piece also covers whitespace around the whole input.
    s.split(',').map(str::trim)
}
196 
197 /// Compress `len` bytes from `decompressed_buf` into `out_buf`.
198 #[allow(unused_variables, unreachable_code)]
compress( encoding: CompressionEncoding, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>199 pub(crate) fn compress(
200     encoding: CompressionEncoding,
201     decompressed_buf: &mut BytesMut,
202     out_buf: &mut BytesMut,
203     len: usize,
204 ) -> Result<(), std::io::Error> {
205     let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
206     out_buf.reserve(capacity);
207 
208     #[cfg(any(feature = "gzip", feature = "zstd"))]
209     let mut out_writer = bytes::BufMut::writer(out_buf);
210 
211     match encoding {
212         #[cfg(feature = "gzip")]
213         CompressionEncoding::Gzip => {
214             let mut gzip_encoder = GzEncoder::new(
215                 &decompressed_buf[0..len],
216                 // FIXME: support customizing the compression level
217                 flate2::Compression::new(6),
218             );
219             std::io::copy(&mut gzip_encoder, &mut out_writer)?;
220         }
221         #[cfg(feature = "zstd")]
222         CompressionEncoding::Zstd => {
223             let mut zstd_encoder = Encoder::new(
224                 &decompressed_buf[0..len],
225                 // FIXME: support customizing the compression level
226                 zstd::DEFAULT_COMPRESSION_LEVEL,
227             )?;
228             std::io::copy(&mut zstd_encoder, &mut out_writer)?;
229         }
230     }
231 
232     decompressed_buf.advance(len);
233 
234     Ok(())
235 }
236 
237 /// Decompress `len` bytes from `compressed_buf` into `out_buf`.
238 #[allow(unused_variables, unreachable_code)]
decompress( encoding: CompressionEncoding, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>239 pub(crate) fn decompress(
240     encoding: CompressionEncoding,
241     compressed_buf: &mut BytesMut,
242     out_buf: &mut BytesMut,
243     len: usize,
244 ) -> Result<(), std::io::Error> {
245     let estimate_decompressed_len = len * 2;
246     let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
247     out_buf.reserve(capacity);
248 
249     #[cfg(any(feature = "gzip", feature = "zstd"))]
250     let mut out_writer = bytes::BufMut::writer(out_buf);
251 
252     match encoding {
253         #[cfg(feature = "gzip")]
254         CompressionEncoding::Gzip => {
255             let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
256             std::io::copy(&mut gzip_decoder, &mut out_writer)?;
257         }
258         #[cfg(feature = "zstd")]
259         CompressionEncoding::Zstd => {
260             let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
261             std::io::copy(&mut zstd_decoder, &mut out_writer)?;
262         }
263     }
264 
265     compressed_buf.advance(len);
266 
267     Ok(())
268 }
269 
/// Per-message switch for overriding the compression configured on a stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
280 
281 impl Default for SingleMessageCompressionOverride {
default() -> Self282     fn default() -> Self {
283         Self::Inherit
284     }
285 }
286