use async_trait::async_trait; use axum_core::extract::{FromRef, FromRequestParts}; use http::request::Parts; use std::{ convert::Infallible, ops::{Deref, DerefMut}, }; /// Extractor for state. /// /// See ["Accessing state in middleware"][state-from-middleware] for how to /// access state in middleware. /// /// [state-from-middleware]: crate::middleware#accessing-state-in-middleware /// /// # With `Router` /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// /// // the application state /// // /// // here you can put configuration, database connection pools, or whatever /// // state you need /// // /// // see "When states need to implement `Clone`" for more details on why we need /// // `#[derive(Clone)]` here. /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// // create a `Router` that holds our state /// let app = Router::new() /// .route("/", get(handler)) /// // provide the state so the router can access it /// .with_state(state); /// /// async fn handler( /// // access the state via the `State` extractor /// // extracting a state of the wrong type results in a compile error /// State(state): State, /// ) { /// // use `state`... /// } /// # let _: axum::Router = app; /// ``` /// /// Note that `State` is an extractor, so be sure to put it before any body /// extractors, see ["the order of extractors"][order-of-extractors]. /// /// [order-of-extractors]: crate::extract#the-order-of-extractors /// /// ## Combining stateful routers /// /// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`] /// When combining [`Router`]s with one of these methods, the [`Router`]s must have /// the same state type. Generally, this can be inferred automatically: /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// // create a `Router` that will be nested within another /// let api = Router::new() /// .route("/posts", get(posts_handler)); /// /// let app = Router::new() /// .nest("/api", api) /// .with_state(state); /// /// async fn posts_handler(State(state): State) { /// // use `state`... /// } /// # let _: axum::Router = app; /// ``` /// /// However, if you are composing [`Router`]s that are defined in separate scopes, /// you may need to annotate the [`State`] type explicitly: /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// fn make_app() -> Router { /// let state = AppState {}; /// /// Router::new() /// .nest("/api", make_api()) /// .with_state(state) // the outer Router's state is inferred /// } /// /// // the inner Router must specify its state type to compose with the /// // outer router /// fn make_api() -> Router { /// Router::new() /// .route("/posts", get(posts_handler)) /// } /// /// async fn posts_handler(State(state): State) { /// // use `state`... /// } /// # let _: axum::Router = make_app(); /// ``` /// /// In short, a [`Router`]'s generic state type defaults to `()` /// (no state) unless [`Router::with_state`] is called or the value /// of the generic type is given explicitly. /// /// [`Router`]: crate::Router /// [`Router::merge`]: crate::Router::merge /// [`Router::nest`]: crate::Router::nest /// [`Router::with_state`]: crate::Router::with_state /// /// # With `MethodRouter` /// /// ``` /// use axum::{routing::get, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// let method_router_with_state = get(handler) /// // provide the state so the handler can access it /// .with_state(state); /// /// async fn handler(State(state): State) { /// // use `state`... /// } /// # async { /// # axum::Server::bind(&"".parse().unwrap()).serve(method_router_with_state.into_make_service()).await.unwrap(); /// # }; /// ``` /// /// # With `Handler` /// /// ``` /// use axum::{routing::get, handler::Handler, extract::State}; /// /// #[derive(Clone)] /// struct AppState {} /// /// let state = AppState {}; /// /// async fn handler(State(state): State) { /// // use `state`... /// } /// /// // provide the state so the handler can access it /// let handler_with_state = handler.with_state(state); /// /// # async { /// axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) /// .serve(handler_with_state.into_make_service()) /// .await /// .expect("server failed"); /// # }; /// ``` /// /// # Substates /// /// [`State`] only allows a single state type but you can use [`FromRef`] to extract "substates": /// /// ``` /// use axum::{Router, routing::get, extract::{State, FromRef}}; /// /// // the application state /// #[derive(Clone)] /// struct AppState { /// // that holds some api specific state /// api_state: ApiState, /// } /// /// // the api specific state /// #[derive(Clone)] /// struct ApiState {} /// /// // support converting an `AppState` in an `ApiState` /// impl FromRef for ApiState { /// fn from_ref(app_state: &AppState) -> ApiState { /// app_state.api_state.clone() /// } /// } /// /// let state = AppState { /// api_state: ApiState {}, /// }; /// /// let app = Router::new() /// .route("/", get(handler)) /// .route("/api/users", get(api_users)) /// .with_state(state); /// /// async fn api_users( /// // access the api specific state /// State(api_state): State, /// ) { /// } /// /// async fn handler( /// // we can still access to top level state /// State(state): State, /// ) { /// } /// # let _: axum::Router = app; /// ``` /// /// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`. /// /// # For library authors /// /// If you're writing a library that has an extractor that needs state, this is the recommended way /// to do it: /// /// ```rust /// use axum_core::extract::{FromRequestParts, FromRef}; /// use http::request::Parts; /// use async_trait::async_trait; /// use std::convert::Infallible; /// /// // the extractor your library provides /// struct MyLibraryExtractor; /// /// #[async_trait] /// impl FromRequestParts for MyLibraryExtractor /// where /// // keep `S` generic but require that it can produce a `MyLibraryState` /// // this means users will have to implement `FromRef for MyLibraryState` /// MyLibraryState: FromRef, /// S: Send + Sync, /// { /// type Rejection = Infallible; /// /// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { /// // get a `MyLibraryState` from a reference to the state /// let state = MyLibraryState::from_ref(state); /// /// // ... /// # todo!() /// } /// } /// /// // the state your library needs /// struct MyLibraryState { /// // ... /// } /// ``` /// /// # When states need to implement `Clone` /// /// Your top level state type must implement `Clone` to be extractable with `State`: /// /// ``` /// use axum::extract::State; /// /// // no substates, so to extract to `State` we must implement `Clone` for `AppState` /// #[derive(Clone)] /// struct AppState {} /// /// async fn handler(State(state): State) { /// // ... /// } /// ``` /// /// This works because of [`impl FromRef for S where S: Clone`][`FromRef`]. /// /// This is also true if you're extracting substates, unless you _never_ extract the top level /// state itself: /// /// ``` /// use axum::extract::{State, FromRef}; /// /// // we never extract `State`, just `State`. So `AppState` doesn't need to /// // implement `Clone` /// struct AppState { /// inner: InnerState, /// } /// /// #[derive(Clone)] /// struct InnerState {} /// /// impl FromRef for InnerState { /// fn from_ref(app_state: &AppState) -> InnerState { /// app_state.inner.clone() /// } /// } /// /// async fn api_users(State(inner): State) { /// // ... /// } /// ``` /// /// In general however we recommend you implement `Clone` for all your state types to avoid /// potential type errors. /// /// # Shared mutable state /// /// [As state is global within a `Router`][global] you can't directly get a mutable reference to /// the state. /// /// The most basic solution is to use an `Arc>`. Which kind of mutex you need depends on /// your use case. See [the tokio docs] for more details. /// /// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send` /// futures which are incompatible with axum. If you need to hold a mutex across `.await` points, /// consider using a `tokio::sync::Mutex` instead. /// /// ## Example /// /// ``` /// use axum::{Router, routing::get, extract::State}; /// use std::sync::{Arc, Mutex}; /// /// #[derive(Clone)] /// struct AppState { /// data: Arc>, /// } /// /// async fn handler(State(state): State) { /// let mut data = state.data.lock().expect("mutex was poisoned"); /// *data = "updated foo".to_owned(); /// /// // ... /// } /// /// let state = AppState { /// data: Arc::new(Mutex::new("foo".to_owned())), /// }; /// /// let app = Router::new() /// .route("/", get(handler)) /// .with_state(state); /// # let _: Router = app; /// ``` /// /// [global]: crate::Router::with_state /// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); #[async_trait] impl FromRequestParts for State where InnerState: FromRef, OuterState: Send + Sync, { type Rejection = Infallible; async fn from_request_parts( _parts: &mut Parts, state: &OuterState, ) -> Result { let inner_state = InnerState::from_ref(state); Ok(Self(inner_state)) } } impl Deref for State { type Target = S; fn deref(&self) -> &Self::Target { &self.0 } } impl DerefMut for State { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } }