1 use super::encode::BUFFER_SIZE;
2 use crate::{metadata::MetadataValue, Status};
3 use bytes::{Buf, BytesMut};
4 #[cfg(feature = "gzip")]
5 use flate2::read::{GzDecoder, GzEncoder};
6 use std::fmt;
7 #[cfg(feature = "zstd")]
8 use zstd::stream::read::{Decoder, Encoder};
9
/// Name of the header that states which encoding a message was compressed with.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that lists the encodings a peer is willing to accept.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
12
/// Struct used to configure which encodings are enabled on a server or channel.
///
/// Each field only exists when the matching cargo feature is compiled in, and the
/// derived `Default` leaves every encoding disabled.
#[derive(Debug, Default, Clone, Copy)]
pub struct EnabledCompressionEncodings {
    /// `true` once gzip has been enabled.
    #[cfg(feature = "gzip")]
    pub(crate) gzip: bool,
    /// `true` once zstd has been enabled.
    #[cfg(feature = "zstd")]
    pub(crate) zstd: bool,
}
21
22 impl EnabledCompressionEncodings {
23 /// Check if a [`CompressionEncoding`] is enabled.
is_enabled(&self, encoding: CompressionEncoding) -> bool24 pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
25 match encoding {
26 #[cfg(feature = "gzip")]
27 CompressionEncoding::Gzip => self.gzip,
28 #[cfg(feature = "zstd")]
29 CompressionEncoding::Zstd => self.zstd,
30 }
31 }
32
33 /// Enable a [`CompressionEncoding`].
enable(&mut self, encoding: CompressionEncoding)34 pub fn enable(&mut self, encoding: CompressionEncoding) {
35 match encoding {
36 #[cfg(feature = "gzip")]
37 CompressionEncoding::Gzip => self.gzip = true,
38 #[cfg(feature = "zstd")]
39 CompressionEncoding::Zstd => self.zstd = true,
40 }
41 }
42
into_accept_encoding_header_value(self) -> Option<http::HeaderValue>43 pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
44 match (self.is_gzip_enabled(), self.is_zstd_enabled()) {
45 (true, false) => Some(http::HeaderValue::from_static("gzip,identity")),
46 (false, true) => Some(http::HeaderValue::from_static("zstd,identity")),
47 (true, true) => Some(http::HeaderValue::from_static("gzip,zstd,identity")),
48 (false, false) => None,
49 }
50 }
51
52 #[cfg(feature = "gzip")]
is_gzip_enabled(&self) -> bool53 const fn is_gzip_enabled(&self) -> bool {
54 self.gzip
55 }
56
57 #[cfg(not(feature = "gzip"))]
is_gzip_enabled(&self) -> bool58 const fn is_gzip_enabled(&self) -> bool {
59 false
60 }
61
62 #[cfg(feature = "zstd")]
is_zstd_enabled(&self) -> bool63 const fn is_zstd_enabled(&self) -> bool {
64 self.zstd
65 }
66
67 #[cfg(not(feature = "zstd"))]
is_zstd_enabled(&self) -> bool68 const fn is_zstd_enabled(&self) -> bool {
69 false
70 }
71 }
72
/// The compression encodings Tonic supports.
///
/// Each variant is only available when the cargo feature of the same name is enabled,
/// so with no compression feature compiled in this enum has no variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// Gzip compression, named `gzip` in the `grpc-encoding` header.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,
    /// Zstandard compression, named `zstd` in the `grpc-encoding` header.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
86
87 impl CompressionEncoding {
88 /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
from_accept_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option<Self>89 pub(crate) fn from_accept_encoding_header(
90 map: &http::HeaderMap,
91 enabled_encodings: EnabledCompressionEncodings,
92 ) -> Option<Self> {
93 if !enabled_encodings.is_gzip_enabled() && !enabled_encodings.is_zstd_enabled() {
94 return None;
95 }
96
97 let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
98 let header_value_str = header_value.to_str().ok()?;
99
100 split_by_comma(header_value_str).find_map(|value| match value {
101 #[cfg(feature = "gzip")]
102 "gzip" => Some(CompressionEncoding::Gzip),
103 #[cfg(feature = "zstd")]
104 "zstd" => Some(CompressionEncoding::Zstd),
105 _ => None,
106 })
107 }
108
109 /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
from_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Result<Option<Self>, Status>110 pub(crate) fn from_encoding_header(
111 map: &http::HeaderMap,
112 enabled_encodings: EnabledCompressionEncodings,
113 ) -> Result<Option<Self>, Status> {
114 let header_value = if let Some(value) = map.get(ENCODING_HEADER) {
115 value
116 } else {
117 return Ok(None);
118 };
119
120 let header_value_str = if let Ok(value) = header_value.to_str() {
121 value
122 } else {
123 return Ok(None);
124 };
125
126 match header_value_str {
127 #[cfg(feature = "gzip")]
128 "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
129 Ok(Some(CompressionEncoding::Gzip))
130 }
131 #[cfg(feature = "zstd")]
132 "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
133 Ok(Some(CompressionEncoding::Zstd))
134 }
135 "identity" => Ok(None),
136 other => {
137 let mut status = Status::unimplemented(format!(
138 "Content is compressed with `{}` which isn't supported",
139 other
140 ));
141
142 let header_value = enabled_encodings
143 .into_accept_encoding_header_value()
144 .map(MetadataValue::unchecked_from_header_value)
145 .unwrap_or_else(|| MetadataValue::from_static("identity"));
146 status
147 .metadata_mut()
148 .insert(ACCEPT_ENCODING_HEADER, header_value);
149
150 Err(status)
151 }
152 }
153 }
154
155 #[allow(missing_docs)]
156 #[cfg(any(feature = "gzip", feature = "zstd"))]
as_str(&self) -> &'static str157 pub(crate) fn as_str(&self) -> &'static str {
158 match self {
159 #[cfg(feature = "gzip")]
160 CompressionEncoding::Gzip => "gzip",
161 #[cfg(feature = "zstd")]
162 CompressionEncoding::Zstd => "zstd",
163 }
164 }
165
166 #[cfg(any(feature = "gzip", feature = "zstd"))]
into_header_value(self) -> http::HeaderValue167 pub(crate) fn into_header_value(self) -> http::HeaderValue {
168 http::HeaderValue::from_static(self.as_str())
169 }
170
encodings() -> &'static [Self]171 pub(crate) fn encodings() -> &'static [Self] {
172 &[
173 #[cfg(feature = "gzip")]
174 CompressionEncoding::Gzip,
175 #[cfg(feature = "zstd")]
176 CompressionEncoding::Zstd,
177 ]
178 }
179 }
180
181 impl fmt::Display for CompressionEncoding {
182 #[allow(unused_variables)]
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 match *self {
185 #[cfg(feature = "gzip")]
186 CompressionEncoding::Gzip => write!(f, "gzip"),
187 #[cfg(feature = "zstd")]
188 CompressionEncoding::Zstd => write!(f, "zstd"),
189 }
190 }
191 }
192
/// Break a comma-separated value into its entries.
///
/// Whitespace around the whole input and around every entry is trimmed. Empty entries
/// are kept and yielded as empty strings.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    let whole = s.trim();
    whole.split(',').map(str::trim)
}
196
197 /// Compress `len` bytes from `decompressed_buf` into `out_buf`.
198 #[allow(unused_variables, unreachable_code)]
compress( encoding: CompressionEncoding, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>199 pub(crate) fn compress(
200 encoding: CompressionEncoding,
201 decompressed_buf: &mut BytesMut,
202 out_buf: &mut BytesMut,
203 len: usize,
204 ) -> Result<(), std::io::Error> {
205 let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
206 out_buf.reserve(capacity);
207
208 #[cfg(any(feature = "gzip", feature = "zstd"))]
209 let mut out_writer = bytes::BufMut::writer(out_buf);
210
211 match encoding {
212 #[cfg(feature = "gzip")]
213 CompressionEncoding::Gzip => {
214 let mut gzip_encoder = GzEncoder::new(
215 &decompressed_buf[0..len],
216 // FIXME: support customizing the compression level
217 flate2::Compression::new(6),
218 );
219 std::io::copy(&mut gzip_encoder, &mut out_writer)?;
220 }
221 #[cfg(feature = "zstd")]
222 CompressionEncoding::Zstd => {
223 let mut zstd_encoder = Encoder::new(
224 &decompressed_buf[0..len],
225 // FIXME: support customizing the compression level
226 zstd::DEFAULT_COMPRESSION_LEVEL,
227 )?;
228 std::io::copy(&mut zstd_encoder, &mut out_writer)?;
229 }
230 }
231
232 decompressed_buf.advance(len);
233
234 Ok(())
235 }
236
237 /// Decompress `len` bytes from `compressed_buf` into `out_buf`.
238 #[allow(unused_variables, unreachable_code)]
decompress( encoding: CompressionEncoding, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>239 pub(crate) fn decompress(
240 encoding: CompressionEncoding,
241 compressed_buf: &mut BytesMut,
242 out_buf: &mut BytesMut,
243 len: usize,
244 ) -> Result<(), std::io::Error> {
245 let estimate_decompressed_len = len * 2;
246 let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE;
247 out_buf.reserve(capacity);
248
249 #[cfg(any(feature = "gzip", feature = "zstd"))]
250 let mut out_writer = bytes::BufMut::writer(out_buf);
251
252 match encoding {
253 #[cfg(feature = "gzip")]
254 CompressionEncoding::Gzip => {
255 let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
256 std::io::copy(&mut gzip_decoder, &mut out_writer)?;
257 }
258 #[cfg(feature = "zstd")]
259 CompressionEncoding::Zstd => {
260 let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
261 std::io::copy(&mut zstd_decoder, &mut out_writer)?;
262 }
263 }
264
265 compressed_buf.advance(len);
266
267 Ok(())
268 }
269
/// Per-message switch that can opt a single message out of the stream's compression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
280
281 impl Default for SingleMessageCompressionOverride {
default() -> Self282 fn default() -> Self {
283 Self::Inherit
284 }
285 }
286