1 use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings}; 2 use crate::{ 3 body::BoxBody, 4 client::GrpcService, 5 codec::{encode_client, Codec, Decoder, Streaming}, 6 request::SanitizeHeaders, 7 Code, Request, Response, Status, 8 }; 9 use http::{ 10 header::{HeaderValue, CONTENT_TYPE, TE}, 11 uri::{PathAndQuery, Uri}, 12 }; 13 use http_body::Body; 14 use std::{fmt, future}; 15 use tokio_stream::{Stream, StreamExt}; 16 17 /// A gRPC client dispatcher. 18 /// 19 /// This will wrap some inner [`GrpcService`] and will encode/decode 20 /// messages via the provided codec. 21 /// 22 /// Each request method takes a [`Request`], a [`PathAndQuery`], and a 23 /// [`Codec`]. The request contains the message to send via the 24 /// [`Codec::encoder`]. The path determines the fully qualified path 25 /// that will be append to the outgoing uri. The path must follow 26 /// the conventions explained in the [gRPC protocol definition] under `Path →`. An 27 /// example of this path could look like `/greeter.Greeter/SayHello`. 28 /// 29 /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests 30 pub struct Grpc<T> { 31 inner: T, 32 config: GrpcConfig, 33 } 34 35 struct GrpcConfig { 36 origin: Uri, 37 /// Which compression encodings does the client accept? 38 accept_compression_encodings: EnabledCompressionEncodings, 39 /// The compression encoding that will be applied to requests. 40 send_compression_encodings: Option<CompressionEncoding>, 41 /// Limits the maximum size of a decoded message. 42 max_decoding_message_size: Option<usize>, 43 /// Limits the maximum size of an encoded message. 44 max_encoding_message_size: Option<usize>, 45 } 46 47 impl<T> Grpc<T> { 48 /// Creates a new gRPC client with the provided [`GrpcService`]. new(inner: T) -> Self49 pub fn new(inner: T) -> Self { 50 Self::with_origin(inner, Uri::default()) 51 } 52 53 /// Creates a new gRPC client with the provided [`GrpcService`] and `Uri`. 54 /// 55 /// The provided Uri will use only the scheme and authority parts as the 56 /// path_and_query portion will be set for each method. with_origin(inner: T, origin: Uri) -> Self57 pub fn with_origin(inner: T, origin: Uri) -> Self { 58 Self { 59 inner, 60 config: GrpcConfig { 61 origin, 62 send_compression_encodings: None, 63 accept_compression_encodings: EnabledCompressionEncodings::default(), 64 max_decoding_message_size: None, 65 max_encoding_message_size: None, 66 }, 67 } 68 } 69 70 /// Compress requests with the provided encoding. 71 /// 72 /// Requires the server to accept the specified encoding, otherwise it might return an error. 73 /// 74 /// # Example 75 /// 76 /// The most common way of using this is through a client generated by tonic-build: 77 /// 78 /// ```rust 79 /// use tonic::transport::Channel; 80 /// # enum CompressionEncoding { Gzip } 81 /// # struct TestClient<T>(T); 82 /// # impl<T> TestClient<T> { 83 /// # fn new(channel: T) -> Self { Self(channel) } 84 /// # fn send_compressed(self, _: CompressionEncoding) -> Self { self } 85 /// # } 86 /// 87 /// # async { 88 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 89 /// .connect() 90 /// .await 91 /// .unwrap(); 92 /// 93 /// let client = TestClient::new(channel).send_compressed(CompressionEncoding::Gzip); 94 /// # }; 95 /// ``` send_compressed(mut self, encoding: CompressionEncoding) -> Self96 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { 97 self.config.send_compression_encodings = Some(encoding); 98 self 99 } 100 101 /// Enable accepting compressed responses. 102 /// 103 /// Requires the server to also support sending compressed responses. 104 /// 105 /// # Example 106 /// 107 /// The most common way of using this is through a client generated by tonic-build: 108 /// 109 /// ```rust 110 /// use tonic::transport::Channel; 111 /// # enum CompressionEncoding { Gzip } 112 /// # struct TestClient<T>(T); 113 /// # impl<T> TestClient<T> { 114 /// # fn new(channel: T) -> Self { Self(channel) } 115 /// # fn accept_compressed(self, _: CompressionEncoding) -> Self { self } 116 /// # } 117 /// 118 /// # async { 119 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 120 /// .connect() 121 /// .await 122 /// .unwrap(); 123 /// 124 /// let client = TestClient::new(channel).accept_compressed(CompressionEncoding::Gzip); 125 /// # }; 126 /// ``` accept_compressed(mut self, encoding: CompressionEncoding) -> Self127 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { 128 self.config.accept_compression_encodings.enable(encoding); 129 self 130 } 131 132 /// Limits the maximum size of a decoded message. 133 /// 134 /// # Example 135 /// 136 /// The most common way of using this is through a client generated by tonic-build: 137 /// 138 /// ```rust 139 /// use tonic::transport::Channel; 140 /// # struct TestClient<T>(T); 141 /// # impl<T> TestClient<T> { 142 /// # fn new(channel: T) -> Self { Self(channel) } 143 /// # fn max_decoding_message_size(self, _: usize) -> Self { self } 144 /// # } 145 /// 146 /// # async { 147 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 148 /// .connect() 149 /// .await 150 /// .unwrap(); 151 /// 152 /// // Set the limit to 2MB, Defaults to 4MB. 153 /// let limit = 2 * 1024 * 1024; 154 /// let client = TestClient::new(channel).max_decoding_message_size(limit); 155 /// # }; 156 /// ``` max_decoding_message_size(mut self, limit: usize) -> Self157 pub fn max_decoding_message_size(mut self, limit: usize) -> Self { 158 self.config.max_decoding_message_size = Some(limit); 159 self 160 } 161 162 /// Limits the maximum size of an ecoded message. 163 /// 164 /// # Example 165 /// 166 /// The most common way of using this is through a client generated by tonic-build: 167 /// 168 /// ```rust 169 /// use tonic::transport::Channel; 170 /// # struct TestClient<T>(T); 171 /// # impl<T> TestClient<T> { 172 /// # fn new(channel: T) -> Self { Self(channel) } 173 /// # fn max_encoding_message_size(self, _: usize) -> Self { self } 174 /// # } 175 /// 176 /// # async { 177 /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) 178 /// .connect() 179 /// .await 180 /// .unwrap(); 181 /// 182 /// // Set the limit to 2MB, Defaults to 4MB. 183 /// let limit = 2 * 1024 * 1024; 184 /// let client = TestClient::new(channel).max_encoding_message_size(limit); 185 /// # }; 186 /// ``` max_encoding_message_size(mut self, limit: usize) -> Self187 pub fn max_encoding_message_size(mut self, limit: usize) -> Self { 188 self.config.max_encoding_message_size = Some(limit); 189 self 190 } 191 192 /// Check if the inner [`GrpcService`] is able to accept a new request. 193 /// 194 /// This will call [`GrpcService::poll_ready`] until it returns ready or 195 /// an error. If this returns ready the inner [`GrpcService`] is ready to 196 /// accept one more request. ready(&mut self) -> Result<(), T::Error> where T: GrpcService<BoxBody>,197 pub async fn ready(&mut self) -> Result<(), T::Error> 198 where 199 T: GrpcService<BoxBody>, 200 { 201 future::poll_fn(|cx| self.inner.poll_ready(cx)).await 202 } 203 204 /// Send a single unary gRPC request. unary<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,205 pub async fn unary<M1, M2, C>( 206 &mut self, 207 request: Request<M1>, 208 path: PathAndQuery, 209 codec: C, 210 ) -> Result<Response<M2>, Status> 211 where 212 T: GrpcService<BoxBody>, 213 T::ResponseBody: Body + Send + 'static, 214 <T::ResponseBody as Body>::Error: Into<crate::Error>, 215 C: Codec<Encode = M1, Decode = M2>, 216 M1: Send + Sync + 'static, 217 M2: Send + Sync + 'static, 218 { 219 let request = request.map(|m| tokio_stream::once(m)); 220 self.client_streaming(request, path, codec).await 221 } 222 223 /// Send a client side streaming gRPC request. client_streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, codec: C, ) -> Result<Response<M2>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,224 pub async fn client_streaming<S, M1, M2, C>( 225 &mut self, 226 request: Request<S>, 227 path: PathAndQuery, 228 codec: C, 229 ) -> Result<Response<M2>, Status> 230 where 231 T: GrpcService<BoxBody>, 232 T::ResponseBody: Body + Send + 'static, 233 <T::ResponseBody as Body>::Error: Into<crate::Error>, 234 S: Stream<Item = M1> + Send + 'static, 235 C: Codec<Encode = M1, Decode = M2>, 236 M1: Send + Sync + 'static, 237 M2: Send + Sync + 'static, 238 { 239 let (mut parts, body, extensions) = 240 self.streaming(request, path, codec).await?.into_parts(); 241 242 tokio::pin!(body); 243 244 let message = body 245 .try_next() 246 .await 247 .map_err(|mut status| { 248 status.metadata_mut().merge(parts.clone()); 249 status 250 })? 251 .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?; 252 253 if let Some(trailers) = body.trailers().await? { 254 parts.merge(trailers); 255 } 256 257 Ok(Response::from_parts(parts, message, extensions)) 258 } 259 260 /// Send a server side streaming gRPC request. server_streaming<M1, M2, C>( &mut self, request: Request<M1>, path: PathAndQuery, codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,261 pub async fn server_streaming<M1, M2, C>( 262 &mut self, 263 request: Request<M1>, 264 path: PathAndQuery, 265 codec: C, 266 ) -> Result<Response<Streaming<M2>>, Status> 267 where 268 T: GrpcService<BoxBody>, 269 T::ResponseBody: Body + Send + 'static, 270 <T::ResponseBody as Body>::Error: Into<crate::Error>, 271 C: Codec<Encode = M1, Decode = M2>, 272 M1: Send + Sync + 'static, 273 M2: Send + Sync + 'static, 274 { 275 let request = request.map(|m| tokio_stream::once(m)); 276 self.streaming(request, path, codec).await 277 } 278 279 /// Send a bi-directional streaming gRPC request. streaming<S, M1, M2, C>( &mut self, request: Request<S>, path: PathAndQuery, mut codec: C, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>, S: Stream<Item = M1> + Send + 'static, C: Codec<Encode = M1, Decode = M2>, M1: Send + Sync + 'static, M2: Send + Sync + 'static,280 pub async fn streaming<S, M1, M2, C>( 281 &mut self, 282 request: Request<S>, 283 path: PathAndQuery, 284 mut codec: C, 285 ) -> Result<Response<Streaming<M2>>, Status> 286 where 287 T: GrpcService<BoxBody>, 288 T::ResponseBody: Body + Send + 'static, 289 <T::ResponseBody as Body>::Error: Into<crate::Error>, 290 S: Stream<Item = M1> + Send + 'static, 291 C: Codec<Encode = M1, Decode = M2>, 292 M1: Send + Sync + 'static, 293 M2: Send + Sync + 'static, 294 { 295 let request = request 296 .map(|s| { 297 encode_client( 298 codec.encoder(), 299 s, 300 self.config.send_compression_encodings, 301 self.config.max_encoding_message_size, 302 ) 303 }) 304 .map(BoxBody::new); 305 306 let request = self.config.prepare_request(request, path); 307 308 let response = self 309 .inner 310 .call(request) 311 .await 312 .map_err(Status::from_error_generic)?; 313 314 let decoder = codec.decoder(); 315 316 self.create_response(decoder, response) 317 } 318 319 // Keeping this code in a separate function from Self::streaming lets functions that return the 320 // same output share the generated binary code create_response<M2>( &self, decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, response: http::Response<T::ResponseBody>, ) -> Result<Response<Streaming<M2>>, Status> where T: GrpcService<BoxBody>, T::ResponseBody: Body + Send + 'static, <T::ResponseBody as Body>::Error: Into<crate::Error>,321 fn create_response<M2>( 322 &self, 323 decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static, 324 response: http::Response<T::ResponseBody>, 325 ) -> Result<Response<Streaming<M2>>, Status> 326 where 327 T: GrpcService<BoxBody>, 328 T::ResponseBody: Body + Send + 'static, 329 <T::ResponseBody as Body>::Error: Into<crate::Error>, 330 { 331 let encoding = CompressionEncoding::from_encoding_header( 332 response.headers(), 333 self.config.accept_compression_encodings, 334 )?; 335 336 let status_code = response.status(); 337 let trailers_only_status = Status::from_header_map(response.headers()); 338 339 // We do not need to check for trailers if the `grpc-status` header is present 340 // with a valid code. 341 let expect_additional_trailers = if let Some(status) = trailers_only_status { 342 if status.code() != Code::Ok { 343 return Err(status); 344 } 345 346 false 347 } else { 348 true 349 }; 350 351 let response = response.map(|body| { 352 if expect_additional_trailers { 353 Streaming::new_response( 354 decoder, 355 body, 356 status_code, 357 encoding, 358 self.config.max_decoding_message_size, 359 ) 360 } else { 361 Streaming::new_empty(decoder, body) 362 } 363 }); 364 365 Ok(Response::from_http(response)) 366 } 367 } 368 369 impl GrpcConfig { prepare_request( &self, request: Request<BoxBody>, path: PathAndQuery, ) -> http::Request<BoxBody>370 fn prepare_request( 371 &self, 372 request: Request<BoxBody>, 373 path: PathAndQuery, 374 ) -> http::Request<BoxBody> { 375 let mut parts = self.origin.clone().into_parts(); 376 377 match &parts.path_and_query { 378 Some(pnq) if pnq != "/" => { 379 parts.path_and_query = Some( 380 format!("{}{}", pnq.path(), path) 381 .parse() 382 .expect("must form valid path_and_query"), 383 ) 384 } 385 _ => { 386 parts.path_and_query = Some(path); 387 } 388 } 389 390 let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri"); 391 392 let mut request = request.into_http( 393 uri, 394 http::Method::POST, 395 http::Version::HTTP_2, 396 SanitizeHeaders::Yes, 397 ); 398 399 // Add the gRPC related HTTP headers 400 request 401 .headers_mut() 402 .insert(TE, HeaderValue::from_static("trailers")); 403 404 // Set the content type 405 request 406 .headers_mut() 407 .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc")); 408 409 #[cfg(any(feature = "gzip", feature = "zstd"))] 410 if let Some(encoding) = self.send_compression_encodings { 411 request.headers_mut().insert( 412 crate::codec::compression::ENCODING_HEADER, 413 encoding.into_header_value(), 414 ); 415 } 416 417 if let Some(header_value) = self 418 .accept_compression_encodings 419 .into_accept_encoding_header_value() 420 { 421 request.headers_mut().insert( 422 crate::codec::compression::ACCEPT_ENCODING_HEADER, 423 header_value, 424 ); 425 } 426 427 request 428 } 429 } 430 431 impl<T: Clone> Clone for Grpc<T> { clone(&self) -> Self432 fn clone(&self) -> Self { 433 Self { 434 inner: self.inner.clone(), 435 config: GrpcConfig { 436 origin: self.config.origin.clone(), 437 send_compression_encodings: self.config.send_compression_encodings, 438 accept_compression_encodings: self.config.accept_compression_encodings, 439 max_encoding_message_size: self.config.max_encoding_message_size, 440 max_decoding_message_size: self.config.max_decoding_message_size, 441 }, 442 } 443 } 444 } 445 446 impl<T: fmt::Debug> fmt::Debug for Grpc<T> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result447 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 448 let mut f = f.debug_struct("Grpc"); 449 450 f.field("inner", &self.inner); 451 452 f.field("origin", &self.config.origin); 453 454 f.field( 455 "compression_encoding", 456 &self.config.send_compression_encodings, 457 ); 458 459 f.field( 460 "accept_compression_encodings", 461 &self.config.accept_compression_encodings, 462 ); 463 464 f.field( 465 "max_decoding_message_size", 466 &self.config.max_decoding_message_size, 467 ); 468 469 f.field( 470 "max_encoding_message_size", 471 &self.config.max_encoding_message_size, 472 ); 473 474 f.finish() 475 } 476 } 477