1 use crate::{extract::rejection::*, response::IntoResponseParts};
2 use async_trait::async_trait;
3 use axum_core::{
4     extract::FromRequestParts,
5     response::{IntoResponse, Response, ResponseParts},
6 };
7 use http::{request::Parts, Request};
8 use std::{
9     convert::Infallible,
10     task::{Context, Poll},
11 };
12 use tower_service::Service;
13 
/// Extractor and response for extensions.
///
/// # As extractor
///
/// This is commonly used to share state across handlers.
///
/// ```rust,no_run
/// use axum::{
///     Router,
///     Extension,
///     routing::get,
/// };
/// use std::sync::Arc;
///
/// // Some shared state used throughout our application
/// struct State {
///     // ...
/// }
///
/// async fn handler(state: Extension<Arc<State>>) {
///     // ...
/// }
///
/// let state = Arc::new(State { /* ... */ });
///
/// let app = Router::new().route("/", get(handler))
///     // Add middleware that inserts the state into the extensions of every
///     // incoming request.
///     .layer(Extension(state));
/// # async {
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// The extractor clones the value out of the request extensions, so `T` must
/// implement `Clone`. Wrapping shared state in an `Arc` keeps that clone cheap.
///
/// If the extension is missing it will reject the request with a `500 Internal
/// Server Error` response.
///
/// # As layer
///
/// `Extension<T>` is also a tower `Layer`, as used with `.layer(...)` above. The
/// resulting [`AddExtension`] service inserts a clone of the wrapped value into
/// the extensions of every request before passing the request on.
///
/// # As response
///
/// Response extensions can be used to share state with middleware.
///
/// ```rust
/// use axum::{
///     Extension,
///     response::IntoResponse,
/// };
///
/// async fn handler() -> (Extension<Foo>, &'static str) {
///     (
///         Extension(Foo("foo")),
///         "Hello, World!"
///     )
/// }
///
/// #[derive(Clone)]
/// struct Foo(&'static str);
/// ```
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Extension<T>(pub T);
74 
75 #[async_trait]
76 impl<T, S> FromRequestParts<S> for Extension<T>
77 where
78     T: Clone + Send + Sync + 'static,
79     S: Send + Sync,
80 {
81     type Rejection = ExtensionRejection;
82 
from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection>83     async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
84         let value = req
85             .extensions
86             .get::<T>()
87             .ok_or_else(|| {
88                 MissingExtension::from_err(format!(
89                     "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
90                     std::any::type_name::<T>()
91                 ))
92             })
93             .map(|x| x.clone())?;
94 
95         Ok(Extension(value))
96     }
97 }
98 
// Helper macro from axum_core. Presumably it generates `Deref`/`DerefMut` impls
// that expose the wrapped `T`; confirm against `axum_core::__impl_deref`.
axum_core::__impl_deref!(Extension);
100 
101 impl<T> IntoResponseParts for Extension<T>
102 where
103     T: Send + Sync + 'static,
104 {
105     type Error = Infallible;
106 
into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error>107     fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
108         res.extensions_mut().insert(self.0);
109         Ok(res)
110     }
111 }
112 
113 impl<T> IntoResponse for Extension<T>
114 where
115     T: Send + Sync + 'static,
116 {
into_response(self) -> Response117     fn into_response(self) -> Response {
118         let mut res = ().into_response();
119         res.extensions_mut().insert(self.0);
120         res
121     }
122 }
123 
124 impl<S, T> tower_layer::Layer<S> for Extension<T>
125 where
126     T: Clone + Send + Sync + 'static,
127 {
128     type Service = AddExtension<S, T>;
129 
layer(&self, inner: S) -> Self::Service130     fn layer(&self, inner: S) -> Self::Service {
131         AddExtension {
132             inner,
133             value: self.0.clone(),
134         }
135     }
136 }
137 
/// Middleware for adding some shareable value to [request extensions].
///
/// On every call the stored value is cloned and inserted into the request's
/// extensions before the request is forwarded to the wrapped service.
///
/// See [Sharing state with handlers](index.html#sharing-state-with-handlers)
/// for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
    /// The wrapped service that receives each request after the insertion.
    pub(crate) inner: S,
    /// The value that is cloned into each request's extensions.
    pub(crate) value: T,
}
149 
150 impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
151 where
152     S: Service<Request<ResBody>>,
153     T: Clone + Send + Sync + 'static,
154 {
155     type Response = S::Response;
156     type Error = S::Error;
157     type Future = S::Future;
158 
159     #[inline]
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>160     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161         self.inner.poll_ready(cx)
162     }
163 
call(&mut self, mut req: Request<ResBody>) -> Self::Future164     fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
165         req.extensions_mut().insert(self.value.clone());
166         self.inner.call(req)
167     }
168 }
169