use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; use crate::{ body::BoxBody, client::GrpcService, codec::{encode_client, Codec, Decoder, Streaming}, request::SanitizeHeaders, Code, Request, Response, Status, }; use http::{ header::{HeaderValue, CONTENT_TYPE, TE}, uri::{PathAndQuery, Uri}, }; use http_body::Body; use std::{fmt, future}; use tokio_stream::{Stream, StreamExt}; /// A gRPC client dispatcher. /// /// This will wrap some inner [`GrpcService`] and will encode/decode /// messages via the provided codec. /// /// Each request method takes a [`Request`], a [`PathAndQuery`], and a /// [`Codec`]. The request contains the message to send via the /// [`Codec::encoder`]. The path determines the fully qualified path /// that will be append to the outgoing uri. The path must follow /// the conventions explained in the [gRPC protocol definition] under `Path →`. An /// example of this path could look like `/greeter.Greeter/SayHello`. /// /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, config: GrpcConfig, } struct GrpcConfig { origin: Uri, /// Which compression encodings does the client accept? accept_compression_encodings: EnabledCompressionEncodings, /// The compression encoding that will be applied to requests. send_compression_encodings: Option, /// Limits the maximum size of a decoded message. max_decoding_message_size: Option, /// Limits the maximum size of an encoded message. max_encoding_message_size: Option, } impl Grpc { /// Creates a new gRPC client with the provided [`GrpcService`]. pub fn new(inner: T) -> Self { Self::with_origin(inner, Uri::default()) } /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`. /// /// The provided Uri will use only the scheme and authority parts as the /// path_and_query portion will be set for each method. 
pub fn with_origin(inner: T, origin: Uri) -> Self {
    // Compression starts disabled and both size-limit overrides start unset;
    // the builder-style methods below change these one at a time.
    let config = GrpcConfig {
        origin,
        accept_compression_encodings: EnabledCompressionEncodings::default(),
        send_compression_encodings: None,
        max_encoding_message_size: None,
        max_decoding_message_size: None,
    };
    Self { inner, config }
}

/// Compresses outgoing request messages with `encoding`.
///
/// The server must accept this encoding; otherwise it may answer the call
/// with an error.
///
/// # Example
///
/// Usually this is called on a client generated by tonic-build:
///
/// ```rust
/// use tonic::transport::Channel;
/// # enum CompressionEncoding { Gzip }
/// # struct TestClient<T>(T);
/// # impl<T> TestClient<T> {
/// #     fn new(channel: T) -> Self { Self(channel) }
/// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
/// # }
///
/// # async {
/// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
///     .connect()
///     .await
///     .unwrap();
///
/// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip);
/// # };
/// ```
pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
    let config = &mut self.config;
    config.send_compression_encodings = Some(encoding);
    self
}

/// Accepts responses compressed with `encoding`.
///
/// The server must also support sending responses compressed this way.
///
/// Each enabled encoding is advertised to the server in the accept-encoding
/// header of every request. The enabled set is also consulted when the
/// encoding header of a response is interpreted.
///
/// # Example
///
/// The most common way of using this is through a client generated by tonic-build:
///
/// ```rust
/// use tonic::transport::Channel;
/// # enum CompressionEncoding { Gzip }
/// # struct TestClient<T>(T);
/// # impl<T> TestClient<T> {
/// #     fn new(channel: T) -> Self { Self(channel) }
/// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
/// # }
///
/// # async {
/// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
///     .connect()
///     .await
///     .unwrap();
///
/// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip);
/// # };
/// ```
pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
    self.config.accept_compression_encodings.enable(encoding);
    self
}

/// Limits the maximum size of a decoded message.
///
/// `limit` is a size in bytes. It is handed to the decoder of every
/// response stream created by this client.
///
/// # Example
///
/// The most common way of using this is through a client generated by tonic-build:
///
/// ```rust
/// use tonic::transport::Channel;
/// # struct TestClient<T>(T);
/// # impl<T> TestClient<T> {
/// #     fn new(channel: T) -> Self { Self(channel) }
/// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
/// # }
///
/// # async {
/// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
///     .connect()
///     .await
///     .unwrap();
///
/// // Set the limit to 2MB, Defaults to 4MB.
/// let limit = 2 * 1024 * 1024;
/// let client = TestClient::new(channel).max_decoding_message_size(limit);
/// # };
/// ```
pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
    self.config.max_decoding_message_size = Some(limit);
    self
}

/// Limits the maximum size of an encoded message.
///
/// `limit` is a size in bytes. It is handed to the encoder that serializes
/// outgoing request messages.
///
/// # Example
///
/// Usually this is called on a client generated by tonic-build:
///
/// ```rust
/// use tonic::transport::Channel;
/// # struct TestClient<T>(T);
/// # impl<T> TestClient<T> {
/// #     fn new(channel: T) -> Self { Self(channel) }
/// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
/// # }
///
/// # async {
/// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
///     .connect()
///     .await
///     .unwrap();
///
/// // Cap encoded request messages at 2MB.
/// let limit = 2 * 1024 * 1024;
/// let client = TestClient::new(channel).max_encoding_message_size(limit);
/// # };
/// ```
pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
    let config = &mut self.config;
    config.max_encoding_message_size = Some(limit);
    self
}

/// Waits until the inner [`GrpcService`] can accept a new request.
///
/// Polls [`GrpcService::poll_ready`] until it reports readiness or fails.
/// When this resolves to `Ok(())`, the inner [`GrpcService`] can accept one
/// more request.
pub async fn ready(&mut self) -> Result<(), T::Error>
where
    T: GrpcService<BoxBody>,
{
    let service = &mut self.inner;
    future::poll_fn(move |cx| service.poll_ready(cx)).await
}

/// Sends a single unary gRPC request.
///
/// The request message is wrapped in a one-item stream and dispatched
/// through [`Grpc::client_streaming`], which produces the single response
/// message.
pub async fn unary<M1, M2, C>(
    &mut self,
    request: Request<M1>,
    path: PathAndQuery,
    codec: C,
) -> Result<Response<M2>, Status>
where
    T: GrpcService<BoxBody>,
    T::ResponseBody: Body + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<crate::Error>,
    C: Codec<Encode = M1, Decode = M2>,
    M1: Send + Sync + 'static,
    M2: Send + Sync + 'static,
{
    let single = request.map(tokio_stream::once);
    self.client_streaming(single, path, codec).await
}

/// Sends a client-streaming gRPC request.
///
/// The first response message is returned. If the response has no message,
/// an `Internal` status is returned instead. Any trailers are merged into
/// the returned metadata.
pub async fn client_streaming( &mut self, request: Request, path: PathAndQuery, codec: C, ) -> Result, Status> where T: GrpcService, T::ResponseBody: Body + Send + 'static, ::Error: Into, S: Stream + Send + 'static, C: Codec, M1: Send + Sync + 'static, M2: Send + Sync + 'static, { let (mut parts, body, extensions) = self.streaming(request, path, codec).await?.into_parts(); tokio::pin!(body); let message = body .try_next() .await .map_err(|mut status| { status.metadata_mut().merge(parts.clone()); status })? .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?; if let Some(trailers) = body.trailers().await? { parts.merge(trailers); } Ok(Response::from_parts(parts, message, extensions)) } /// Send a server side streaming gRPC request. pub async fn server_streaming( &mut self, request: Request, path: PathAndQuery, codec: C, ) -> Result>, Status> where T: GrpcService, T::ResponseBody: Body + Send + 'static, ::Error: Into, C: Codec, M1: Send + Sync + 'static, M2: Send + Sync + 'static, { let request = request.map(|m| tokio_stream::once(m)); self.streaming(request, path, codec).await } /// Send a bi-directional streaming gRPC request. 
pub async fn streaming<S, M1, M2, C>(
    &mut self,
    request: Request<S>,
    path: PathAndQuery,
    mut codec: C,
) -> Result<Response<Streaming<M2>>, Status>
where
    T: GrpcService<BoxBody>,
    T::ResponseBody: Body + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<crate::Error>,
    S: Stream<Item = M1> + Send + 'static,
    C: Codec<Encode = M1, Decode = M2>,
    M1: Send + Sync + 'static,
    M2: Send + Sync + 'static,
{
    // Turn the outgoing message stream into a boxed HTTP body, applying the
    // configured request compression and encoded-size limit.
    let request = request
        .map(|s| {
            encode_client(
                codec.encoder(),
                s,
                self.config.send_compression_encodings,
                self.config.max_encoding_message_size,
            )
        })
        .map(BoxBody::new);

    let request = self.config.prepare_request(request, path);

    let response = self
        .inner
        .call(request)
        .await
        .map_err(Status::from_error_generic)?;

    let decoder = codec.decoder();

    self.create_response(decoder, response)
}

// Keeping this code in a separate function from Self::streaming lets functions that return the
// same output share the generated binary code.
fn create_response<M2>(
    &self,
    decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
    response: http::Response<T::ResponseBody>,
) -> Result<Response<Streaming<M2>>, Status>
where
    T: GrpcService<BoxBody>,
    T::ResponseBody: Body + Send + 'static,
    <T::ResponseBody as Body>::Error: Into<crate::Error>,
{
    // Resolve the response compression first, so an unacceptable encoding is
    // reported before anything else.
    let encoding = CompressionEncoding::from_encoding_header(
        response.headers(),
        self.config.accept_compression_encodings,
    )?;

    let status_code = response.status();

    // A `grpc-status` already present in the initial headers marks a
    // trailers-only response. A non-OK code becomes the error, and an OK code
    // means no further trailers need to be read.
    let expect_additional_trailers = match Status::from_header_map(response.headers()) {
        Some(status) if status.code() != Code::Ok => return Err(status),
        Some(_) => false,
        None => true,
    };

    let response = response.map(|body| {
        if expect_additional_trailers {
            Streaming::new_response(
                decoder,
                body,
                status_code,
                encoding,
                self.config.max_decoding_message_size,
            )
        } else {
            Streaming::new_empty(decoder, body)
        }
    });

    Ok(Response::from_http(response))
}
}

impl GrpcConfig {
    /// Builds the outgoing HTTP/2 `POST` request for a gRPC call.
    ///
    /// The URI reuses the origin's scheme and authority. If the origin has a
    /// path other than `/`, that path becomes a prefix for `path`. The origin's
    /// query is not carried over. The `te` and `content-type` headers are
    /// set, along with the configured compression headers.
    fn prepare_request(
        &self,
        request: Request<BoxBody>,
        path: PathAndQuery,
    ) -> http::Request<BoxBody> {
        let mut parts = self.origin.clone().into_parts();

        match &parts.path_and_query {
            Some(pnq) if pnq != "/" => {
                // Method paths normally start with `/` (e.g.
                // `/greeter.Greeter/SayHello`). Trailing slashes on the
                // origin prefix are trimmed in that case, so origins like
                // `http://host/prefix/` or `http://host/?q` do not produce `//`.
                let prefix = if path.as_str().starts_with('/') {
                    pnq.path().trim_end_matches('/')
                } else {
                    pnq.path()
                };
                parts.path_and_query = Some(
                    format!("{}{}", prefix, path)
                        .parse()
                        .expect("must form valid path_and_query"),
                )
            }
            _ => {
                parts.path_and_query = Some(path);
            }
        }

        let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");

        let mut request = request.into_http(
            uri,
            http::Method::POST,
            http::Version::HTTP_2,
            SanitizeHeaders::Yes,
        );

        // Add the gRPC related HTTP headers
        request
            .headers_mut()
            .insert(TE, HeaderValue::from_static("trailers"));

        // Set the content type
        request
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc"));

        #[cfg(any(feature = "gzip", feature = "zstd"))]
        if let Some(encoding) = self.send_compression_encodings {
            request.headers_mut().insert(
                crate::codec::compression::ENCODING_HEADER,
                encoding.into_header_value(),
            );
        }

        if let Some(header_value) = self
            .accept_compression_encodings
            .into_accept_encoding_header_value()
        {
            request.headers_mut().insert(
                crate::codec::compression::ACCEPT_ENCODING_HEADER,
                header_value,
            );
        }

        request
    }
}

impl<T: Clone> Clone for Grpc<T> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            config: GrpcConfig {
                origin: self.config.origin.clone(),
                send_compression_encodings: self.config.send_compression_encodings,
                accept_compression_encodings: self.config.accept_compression_encodings,
                max_encoding_message_size: self.config.max_encoding_message_size,
                max_decoding_message_size: self.config.max_decoding_message_size,
            },
        }
    }
}

impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut f = f.debug_struct("Grpc");

        f.field("inner", &self.inner);

        f.field("origin", &self.config.origin);

        f.field(
            "compression_encoding",
            &self.config.send_compression_encodings,
        );

        f.field(
            "accept_compression_encodings",
            &self.config.accept_compression_encodings,
        );

        f.field(
            "max_decoding_message_size",
            &self.config.max_decoding_message_size,
        );

        f.field(
            "max_encoding_message_size",
            &self.config.max_encoding_message_size,
        );

        f.finish()
    }
}