1 use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
2 use crate::{
3     body::BoxBody,
4     client::GrpcService,
5     codec::{encode_client, Codec, Decoder, Streaming},
6     request::SanitizeHeaders,
7     Code, Request, Response, Status,
8 };
9 use http::{
10     header::{HeaderValue, CONTENT_TYPE, TE},
11     uri::{PathAndQuery, Uri},
12 };
13 use http_body::Body;
14 use std::{fmt, future};
15 use tokio_stream::{Stream, StreamExt};
16 
17 /// A gRPC client dispatcher.
18 ///
19 /// This will wrap some inner [`GrpcService`] and will encode/decode
20 /// messages via the provided codec.
21 ///
22 /// Each request method takes a [`Request`], a [`PathAndQuery`], and a
23 /// [`Codec`]. The request contains the message to send via the
24 /// [`Codec::encoder`]. The path determines the fully qualified path
25 /// that will be append to the outgoing uri. The path must follow
26 /// the conventions explained in the [gRPC protocol definition] under `Path →`. An
27 /// example of this path could look like `/greeter.Greeter/SayHello`.
28 ///
29 /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
30 pub struct Grpc<T> {
31     inner: T,
32     config: GrpcConfig,
33 }
34 
35 struct GrpcConfig {
36     origin: Uri,
37     /// Which compression encodings does the client accept?
38     accept_compression_encodings: EnabledCompressionEncodings,
39     /// The compression encoding that will be applied to requests.
40     send_compression_encodings: Option<CompressionEncoding>,
41     /// Limits the maximum size of a decoded message.
42     max_decoding_message_size: Option<usize>,
43     /// Limits the maximum size of an encoded message.
44     max_encoding_message_size: Option<usize>,
45 }
46 
47 impl<T> Grpc<T> {
48     /// Creates a new gRPC client with the provided [`GrpcService`].
new(inner: T) -> Self49     pub fn new(inner: T) -> Self {
50         Self::with_origin(inner, Uri::default())
51     }
52 
53     /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`.
54     ///
55     /// The provided Uri will use only the scheme and authority parts as the
56     /// path_and_query portion will be set for each method.
with_origin(inner: T, origin: Uri) -> Self57     pub fn with_origin(inner: T, origin: Uri) -> Self {
58         Self {
59             inner,
60             config: GrpcConfig {
61                 origin,
62                 send_compression_encodings: None,
63                 accept_compression_encodings: EnabledCompressionEncodings::default(),
64                 max_decoding_message_size: None,
65                 max_encoding_message_size: None,
66             },
67         }
68     }
69 
70     /// Compress requests with the provided encoding.
71     ///
72     /// Requires the server to accept the specified encoding, otherwise it might return an error.
73     ///
74     /// # Example
75     ///
76     /// The most common way of using this is through a client generated by tonic-build:
77     ///
78     /// ```rust
79     /// use tonic::transport::Channel;
80     /// # enum CompressionEncoding { Gzip }
81     /// # struct TestClient<T>(T);
82     /// # impl<T> TestClient<T> {
83     /// #     fn new(channel: T) -> Self { Self(channel) }
84     /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
85     /// # }
86     ///
87     /// # async {
88     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
89     ///     .connect()
90     ///     .await
91     ///     .unwrap();
92     ///
93     /// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip);
94     /// # };
95     /// ```
send_compressed(mut self, encoding: CompressionEncoding) -> Self96     pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
97         self.config.send_compression_encodings = Some(encoding);
98         self
99     }
100 
101     /// Enable accepting compressed responses.
102     ///
103     /// Requires the server to also support sending compressed responses.
104     ///
105     /// # Example
106     ///
107     /// The most common way of using this is through a client generated by tonic-build:
108     ///
109     /// ```rust
110     /// use tonic::transport::Channel;
111     /// # enum CompressionEncoding { Gzip }
112     /// # struct TestClient<T>(T);
113     /// # impl<T> TestClient<T> {
114     /// #     fn new(channel: T) -> Self { Self(channel) }
115     /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
116     /// # }
117     ///
118     /// # async {
119     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
120     ///     .connect()
121     ///     .await
122     ///     .unwrap();
123     ///
124     /// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip);
125     /// # };
126     /// ```
accept_compressed(mut self, encoding: CompressionEncoding) -> Self127     pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
128         self.config.accept_compression_encodings.enable(encoding);
129         self
130     }
131 
132     /// Limits the maximum size of a decoded message.
133     ///
134     /// # Example
135     ///
136     /// The most common way of using this is through a client generated by tonic-build:
137     ///
138     /// ```rust
139     /// use tonic::transport::Channel;
140     /// # struct TestClient<T>(T);
141     /// # impl<T> TestClient<T> {
142     /// #     fn new(channel: T) -> Self { Self(channel) }
143     /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
144     /// # }
145     ///
146     /// # async {
147     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
148     ///     .connect()
149     ///     .await
150     ///     .unwrap();
151     ///
152     /// // Set the limit to 2MB, Defaults to 4MB.
153     /// let limit = 2 * 1024 * 1024;
154     /// let client = TestClient::new(channel).max_decoding_message_size(limit);
155     /// # };
156     /// ```
max_decoding_message_size(mut self, limit: usize) -> Self157     pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
158         self.config.max_decoding_message_size = Some(limit);
159         self
160     }
161 
162     /// Limits the maximum size of an ecoded message.
163     ///
164     /// # Example
165     ///
166     /// The most common way of using this is through a client generated by tonic-build:
167     ///
168     /// ```rust
169     /// use tonic::transport::Channel;
170     /// # struct TestClient<T>(T);
171     /// # impl<T> TestClient<T> {
172     /// #     fn new(channel: T) -> Self { Self(channel) }
173     /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
174     /// # }
175     ///
176     /// # async {
177     /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
178     ///     .connect()
179     ///     .await
180     ///     .unwrap();
181     ///
182     /// // Set the limit to 2MB, Defaults to 4MB.
183     /// let limit = 2 * 1024 * 1024;
184     /// let client = TestClient::new(channel).max_encoding_message_size(limit);
185     /// # };
186     /// ```
max_encoding_message_size(mut self, limit: usize) -> Self187     pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
188         self.config.max_encoding_message_size = Some(limit);
189         self
190     }
191 
192     /// Check if the inner [`GrpcService`] is able to accept a  new request.
193     ///
194     /// This will call [`GrpcService::poll_ready`] until it returns ready or
195     /// an error. If this returns ready the inner [`GrpcService`] is ready to
196     /// accept one more request.
ready(&mut self) -> Result<(), T::Error> where T: GrpcService<BoxBody>,197     pub async fn ready(&mut self) -> Result<(), T::Error>
198     where
199         T: GrpcService<BoxBody>,
200     {
201         future::poll_fn(|cx| self.inner.poll_ready(cx)).await
202     }
203 
204     /// Send a single unary gRPC request.
unary<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,205     pub async fn unary<M1, M2, C>(
206         &mut self,
207         request: Request<M1>,
208         path: PathAndQuery,
209         codec: C,
210     ) -> Result<Response<M2>, Status>
211     where
212         T: GrpcService<BoxBody>,
213         T::ResponseBody: Body + Send + 'static,
214         <T::ResponseBody as Body>::Error: Into<crate::Error>,
215         C: Codec<Encode = M1, Decode = M2>,
216         M1: Send + Sync + 'static,
217         M2: Send + Sync + 'static,
218     {
219         let request = request.map(|m| tokio_stream::once(m));
220         self.client_streaming(request, path, codec).await
221     }
222 
223     /// Send a client side streaming gRPC request.
client_streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,224     pub async fn client_streaming<S, M1, M2, C>(
225         &mut self,
226         request: Request<S>,
227         path: PathAndQuery,
228         codec: C,
229     ) -> Result<Response<M2>, Status>
230     where
231         T: GrpcService<BoxBody>,
232         T::ResponseBody: Body + Send + 'static,
233         <T::ResponseBody as Body>::Error: Into<crate::Error>,
234         S: Stream<Item = M1> + Send + 'static,
235         C: Codec<Encode = M1, Decode = M2>,
236         M1: Send + Sync + 'static,
237         M2: Send + Sync + 'static,
238     {
239         let (mut parts, body, extensions) =
240             self.streaming(request, path, codec).await?.into_parts();
241 
242         tokio::pin!(body);
243 
244         let message = body
245             .try_next()
246             .await
247             .map_err(|mut status| {
248                 status.metadata_mut().merge(parts.clone());
249                 status
250             })?
251             .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?;
252 
253         if let Some(trailers) = body.trailers().await? {
254             parts.merge(trailers);
255         }
256 
257         Ok(Response::from_parts(parts, message, extensions))
258     }
259 
260     /// Send a server side streaming gRPC request.
server_streaming<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,261     pub async fn server_streaming<M1, M2, C>(
262         &mut self,
263         request: Request<M1>,
264         path: PathAndQuery,
265         codec: C,
266     ) -> Result<Response<Streaming<M2>>, Status>
267     where
268         T: GrpcService<BoxBody>,
269         T::ResponseBody: Body + Send + 'static,
270         <T::ResponseBody as Body>::Error: Into<crate::Error>,
271         C: Codec<Encode = M1, Decode = M2>,
272         M1: Send + Sync + 'static,
273         M2: Send + Sync + 'static,
274     {
275         let request = request.map(|m| tokio_stream::once(m));
276         self.streaming(request, path, codec).await
277     }
278 
279     /// Send a bi-directional streaming gRPC request.
streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, mut codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,280     pub async fn streaming<S, M1, M2, C>(
281         &mut self,
282         request: Request<S>,
283         path: PathAndQuery,
284         mut codec: C,
285     ) -> Result<Response<Streaming<M2>>, Status>
286     where
287         T: GrpcService<BoxBody>,
288         T::ResponseBody: Body + Send + 'static,
289         <T::ResponseBody as Body>::Error: Into<crate::Error>,
290         S: Stream<Item = M1> + Send + 'static,
291         C: Codec<Encode = M1, Decode = M2>,
292         M1: Send + Sync + 'static,
293         M2: Send + Sync + 'static,
294     {
295         let request = request
296             .map(|s| {
297                 encode_client(
298                     codec.encoder(),
299                     s,
300                     self.config.send_compression_encodings,
301                     self.config.max_encoding_message_size,
302                 )
303             })
304             .map(BoxBody::new);
305 
306         let request = self.config.prepare_request(request, path);
307 
308         let response = self
309             .inner
310             .call(request)
311             .await
312             .map_err(Status::from_error_generic)?;
313 
314         let decoder = codec.decoder();
315 
316         self.create_response(decoder, response)
317     }
318 
319     // Keeping this code in a separate function from Self::streaming lets functions that return the
320     // same output share the generated binary code
create_response<M2>( &self, decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, response: http::Response<T::ResponseBody>, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>,321     fn create_response<M2>(
322         &self,
323         decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
324         response: http::Response<T::ResponseBody>,
325     ) -> Result<Response<Streaming<M2>>, Status>
326     where
327         T: GrpcService<BoxBody>,
328         T::ResponseBody: Body + Send + 'static,
329         <T::ResponseBody as Body>::Error: Into<crate::Error>,
330     {
331         let encoding = CompressionEncoding::from_encoding_header(
332             response.headers(),
333             self.config.accept_compression_encodings,
334         )?;
335 
336         let status_code = response.status();
337         let trailers_only_status = Status::from_header_map(response.headers());
338 
339         // We do not need to check for trailers if the `grpc-status` header is present
340         // with a valid code.
341         let expect_additional_trailers = if let Some(status) = trailers_only_status {
342             if status.code() != Code::Ok {
343                 return Err(status);
344             }
345 
346             false
347         } else {
348             true
349         };
350 
351         let response = response.map(|body| {
352             if expect_additional_trailers {
353                 Streaming::new_response(
354                     decoder,
355                     body,
356                     status_code,
357                     encoding,
358                     self.config.max_decoding_message_size,
359                 )
360             } else {
361                 Streaming::new_empty(decoder, body)
362             }
363         });
364 
365         Ok(Response::from_http(response))
366     }
367 }
368 
369 impl GrpcConfig {
prepare_request( &self, request: Request<BoxBody>, path: PathAndQuery, ) -> http::Request<BoxBody>370     fn prepare_request(
371         &self,
372         request: Request<BoxBody>,
373         path: PathAndQuery,
374     ) -> http::Request<BoxBody> {
375         let mut parts = self.origin.clone().into_parts();
376 
377         match &parts.path_and_query {
378             Some(pnq) if pnq != "/" => {
379                 parts.path_and_query = Some(
380                     format!("{}{}", pnq.path(), path)
381                         .parse()
382                         .expect("must form valid path_and_query"),
383                 )
384             }
385             _ => {
386                 parts.path_and_query = Some(path);
387             }
388         }
389 
390         let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
391 
392         let mut request = request.into_http(
393             uri,
394             http::Method::POST,
395             http::Version::HTTP_2,
396             SanitizeHeaders::Yes,
397         );
398 
399         // Add the gRPC related HTTP headers
400         request
401             .headers_mut()
402             .insert(TE, HeaderValue::from_static("trailers"));
403 
404         // Set the content type
405         request
406             .headers_mut()
407             .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc"));
408 
409         #[cfg(any(feature = "gzip", feature = "zstd"))]
410         if let Some(encoding) = self.send_compression_encodings {
411             request.headers_mut().insert(
412                 crate::codec::compression::ENCODING_HEADER,
413                 encoding.into_header_value(),
414             );
415         }
416 
417         if let Some(header_value) = self
418             .accept_compression_encodings
419             .into_accept_encoding_header_value()
420         {
421             request.headers_mut().insert(
422                 crate::codec::compression::ACCEPT_ENCODING_HEADER,
423                 header_value,
424             );
425         }
426 
427         request
428     }
429 }
430 
431 impl<T: Clone> Clone for Grpc<T> {
clone(&self) -> Self432     fn clone(&self) -> Self {
433         Self {
434             inner: self.inner.clone(),
435             config: GrpcConfig {
436                 origin: self.config.origin.clone(),
437                 send_compression_encodings: self.config.send_compression_encodings,
438                 accept_compression_encodings: self.config.accept_compression_encodings,
439                 max_encoding_message_size: self.config.max_encoding_message_size,
440                 max_decoding_message_size: self.config.max_decoding_message_size,
441             },
442         }
443     }
444 }
445 
446 impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result447     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448         let mut f = f.debug_struct("Grpc");
449 
450         f.field("inner", &self.inner);
451 
452         f.field("origin", &self.config.origin);
453 
454         f.field(
455             "compression_encoding",
456             &self.config.send_compression_encodings,
457         );
458 
459         f.field(
460             "accept_compression_encodings",
461             &self.config.accept_compression_encodings,
462         );
463 
464         f.field(
465             "max_decoding_message_size",
466             &self.config.max_decoding_message_size,
467         );
468 
469         f.field(
470             "max_encoding_message_size",
471             &self.config.max_encoding_message_size,
472         );
473 
474         f.finish()
475     }
476 }
477