1 use crate::{extract::rejection::*, response::IntoResponseParts}; 2 use async_trait::async_trait; 3 use axum_core::{ 4 extract::FromRequestParts, 5 response::{IntoResponse, Response, ResponseParts}, 6 }; 7 use http::{request::Parts, Request}; 8 use std::{ 9 convert::Infallible, 10 task::{Context, Poll}, 11 }; 12 use tower_service::Service; 13 14 /// Extractor and response for extensions. 15 /// 16 /// # As extractor 17 /// 18 /// This is commonly used to share state across handlers. 19 /// 20 /// ```rust,no_run 21 /// use axum::{ 22 /// Router, 23 /// Extension, 24 /// routing::get, 25 /// }; 26 /// use std::sync::Arc; 27 /// 28 /// // Some shared state used throughout our application 29 /// struct State { 30 /// // ... 31 /// } 32 /// 33 /// async fn handler(state: Extension<Arc<State>>) { 34 /// // ... 35 /// } 36 /// 37 /// let state = Arc::new(State { /* ... */ }); 38 /// 39 /// let app = Router::new().route("/", get(handler)) 40 /// // Add middleware that inserts the state into all incoming request's 41 /// // extensions. 42 /// .layer(Extension(state)); 43 /// # async { 44 /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); 45 /// # }; 46 /// ``` 47 /// 48 /// If the extension is missing it will reject the request with a `500 Internal 49 /// Server Error` response. 50 /// 51 /// # As response 52 /// 53 /// Response extensions can be used to share state with middleware. 54 /// 55 /// ```rust 56 /// use axum::{ 57 /// Extension, 58 /// response::IntoResponse, 59 /// }; 60 /// 61 /// async fn handler() -> (Extension<Foo>, &'static str) { 62 /// ( 63 /// Extension(Foo("foo")), 64 /// "Hello, World!" 65 /// ) 66 /// } 67 /// 68 /// #[derive(Clone)] 69 /// struct Foo(&'static str); 70 /// ``` 71 #[derive(Debug, Clone, Copy, Default)] 72 #[must_use] 73 pub struct Extension<T>(pub T); 74 75 #[async_trait] 76 impl<T, S> FromRequestParts<S> for Extension<T> 77 where 78 T: Clone + Send + Sync + 'static, 79 S: Send + Sync, 80 { 81 type Rejection = ExtensionRejection; 82 from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>83 async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { 84 let value = req 85 .extensions 86 .get::<T>() 87 .ok_or_else(|| { 88 MissingExtension::from_err(format!( 89 "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.", 90 std::any::type_name::<T>() 91 )) 92 }) 93 .map(|x| x.clone())?; 94 95 Ok(Extension(value)) 96 } 97 } 98 99 axum_core::__impl_deref!(Extension); 100 101 impl<T> IntoResponseParts for Extension<T> 102 where 103 T: Send + Sync + 'static, 104 { 105 type Error = Infallible; 106 into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error>107 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> { 108 res.extensions_mut().insert(self.0); 109 Ok(res) 110 } 111 } 112 113 impl<T> IntoResponse for Extension<T> 114 where 115 T: Send + Sync + 'static, 116 { into_response(self) -> Response117 fn into_response(self) -> Response { 118 let mut res = ().into_response(); 119 res.extensions_mut().insert(self.0); 120 res 121 } 122 } 123 124 impl<S, T> tower_layer::Layer<S> for Extension<T> 125 where 126 T: Clone + Send + Sync + 'static, 127 { 128 type Service = AddExtension<S, T>; 129 layer(&self, inner: S) -> Self::Service130 fn layer(&self, inner: S) -> Self::Service { 131 AddExtension { 132 inner, 133 value: self.0.clone(), 134 } 135 } 136 } 137 138 /// Middleware for adding some shareable value to [request extensions]. 139 /// 140 /// See [Sharing state with handlers](index.html#sharing-state-with-handlers) 141 /// for more details. 142 /// 143 /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html 144 #[derive(Clone, Copy, Debug)] 145 pub struct AddExtension<S, T> { 146 pub(crate) inner: S, 147 pub(crate) value: T, 148 } 149 150 impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T> 151 where 152 S: Service<Request<ResBody>>, 153 T: Clone + Send + Sync + 'static, 154 { 155 type Response = S::Response; 156 type Error = S::Error; 157 type Future = S::Future; 158 159 #[inline] poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>160 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 161 self.inner.poll_ready(cx) 162 } 163 call(&mut self, mut req: Request<ResBody>) -> Self::Future164 fn call(&mut self, mut req: Request<ResBody>) -> Self::Future { 165 req.extensions_mut().insert(self.value.clone()); 166 self.inner.call(req) 167 } 168 } 169