summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml1
-rw-r--r--src/cli.rs9
-rw-r--r--src/setup/middleware.rs20
-rw-r--r--src/setup/mod.rs4
-rw-r--r--src/setup/required.rs107
-rw-r--r--src/ui/middleware.rs15
-rw-r--r--src/ui/mod.rs1
-rw-r--r--src/ui/routes/mod.rs6
9 files changed, 116 insertions, 48 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 65cacbd..1e6cfd8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1430,6 +1430,7 @@ dependencies = [
"thiserror",
"tokio",
"tokio-stream",
+ "tower",
"unicode-casefold",
"unicode-normalization",
"unicode-segmentation",
diff --git a/Cargo.toml b/Cargo.toml
index 82693e3..fe5c90b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,6 +47,7 @@ sqlx = { version = "=0.8.3", features = ["chrono", "runtime-tokio", "sqlite"] }
thiserror = "2.0.11"
tokio = { version = "1.43.0", features = ["rt", "macros", "rt-multi-thread"] }
tokio-stream = { version = "0.1.17", features = ["sync"] }
+tower = "0.5.2"
unicode-casefold = "0.2.0"
unicode-normalization = "0.1.24"
unicode-segmentation = "1.12.0"
diff --git a/src/cli.rs b/src/cli.rs
index 4232c00..7bfdbc0 100644
--- a/src/cli.rs
+++ b/src/cli.rs
@@ -15,12 +15,7 @@ use clap::{CommandFactory, Parser};
use sqlx::sqlite::SqlitePool;
use tokio::net;
-use crate::{
- app::App,
- boot, channel, clock, db, event, expire, invite, message,
- setup::{self, middleware::setup_required},
- ui, user,
-};
+use crate::{app::App, boot, channel, clock, db, event, expire, invite, message, setup, ui, user};
/// Command-line entry point for running the `pilcrow` server.
///
@@ -152,7 +147,7 @@ fn routers(app: &App) -> Router<App> {
app.clone(),
expire::middleware,
))
- .route_layer(middleware::from_fn_with_state(app.clone(), setup_required)),
+ .route_layer(setup::Required(app.clone())),
// API endpoints that handle setup
setup::router(),
// The UI (handles setup state itself)
diff --git a/src/setup/middleware.rs b/src/setup/middleware.rs
deleted file mode 100644
index 5f9996b..0000000
--- a/src/setup/middleware.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-use axum::{
- extract::{Request, State},
- http::StatusCode,
- middleware::Next,
- response::{IntoResponse, Response},
-};
-
-use crate::{app::App, error::Internal};
-
-pub async fn setup_required(State(app): State<App>, request: Request, next: Next) -> Response {
- match app.setup().completed().await {
- Ok(true) => next.run(request).await,
- Ok(false) => (
- StatusCode::SERVICE_UNAVAILABLE,
- "initial setup not completed",
- )
- .into_response(),
- Err(error) => Internal::from(error).into_response(),
- }
-}
diff --git a/src/setup/mod.rs b/src/setup/mod.rs
index 5a8fa37..a4b821c 100644
--- a/src/setup/mod.rs
+++ b/src/setup/mod.rs
@@ -1,6 +1,6 @@
pub mod app;
-pub mod middleware;
pub mod repo;
+mod required;
mod routes;
-pub use self::routes::router;
+pub use self::{required::Required, routes::router};
diff --git a/src/setup/required.rs b/src/setup/required.rs
new file mode 100644
index 0000000..2112e4b
--- /dev/null
+++ b/src/setup/required.rs
@@ -0,0 +1,107 @@
+use axum::{
+ extract::Request,
+ http::StatusCode,
+ response::{IntoResponse, Response},
+};
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tower::{Layer, Service};
+
+use crate::{app::App, error::Internal};
+
+#[derive(Clone)]
+pub struct Required(pub App);
+
+impl Required {
+ pub fn with_fallback<F>(self, fallback: F) -> WithFallback<F> {
+ let Self(app) = self;
+ WithFallback { app, fallback }
+ }
+}
+
+impl<S> Layer<S> for Required {
+ type Service = Middleware<S, Unavailable>;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ let Self(app) = self.clone();
+ Middleware {
+ inner,
+ app,
+ fallback: Unavailable,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct WithFallback<F> {
+ app: App,
+ fallback: F,
+}
+
+impl<S, F> Layer<S> for WithFallback<F>
+where
+ Self: Clone,
+{
+ type Service = Middleware<S, F>;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ let Self { app, fallback } = self.clone();
+ Middleware {
+ inner,
+ app,
+ fallback,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct Middleware<S, F> {
+ inner: S,
+ app: App,
+ fallback: F,
+}
+
+impl<S, F> Service<Request> for Middleware<S, F>
+where
+ Self: Clone,
+ S: Service<Request, Response = Response> + Send + 'static,
+ S::Future: Send,
+ F: IntoResponse + Clone + Send + 'static,
+{
+ type Response = S::Response;
+ type Error = S::Error;
+ type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+ fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ self.inner.poll_ready(ctx)
+ }
+
+ fn call(&mut self, req: Request) -> Self::Future {
+ let Self {
+ mut inner,
+ app,
+ fallback,
+ } = self.clone();
+
+ Box::pin(async move {
+ match app.setup().completed().await {
+ Ok(true) => inner.call(req).await,
+ Ok(false) => Ok(fallback.into_response()),
+ Err(error) => Ok(Internal::from(error).into_response()),
+ }
+ })
+ }
+}
+
+#[derive(Clone)]
+pub struct Unavailable;
+
+impl IntoResponse for Unavailable {
+ fn into_response(self) -> Response {
+ (
+ StatusCode::SERVICE_UNAVAILABLE,
+ "initial setup not completed",
+ )
+ .into_response()
+ }
+}
diff --git a/src/ui/middleware.rs b/src/ui/middleware.rs
deleted file mode 100644
index f60ee1c..0000000
--- a/src/ui/middleware.rs
+++ /dev/null
@@ -1,15 +0,0 @@
-use axum::{
- extract::{Request, State},
- middleware::Next,
- response::{IntoResponse, Redirect, Response},
-};
-
-use crate::{app::App, error::Internal};
-
-pub async fn setup_required(State(app): State<App>, request: Request, next: Next) -> Response {
- match app.setup().completed().await {
- Ok(true) => next.run(request).await,
- Ok(false) => Redirect::to("/setup").into_response(),
- Err(error) => Internal::from(error).into_response(),
- }
-}
diff --git a/src/ui/mod.rs b/src/ui/mod.rs
index f8caa48..e834bba 100644
--- a/src/ui/mod.rs
+++ b/src/ui/mod.rs
@@ -1,6 +1,5 @@
mod assets;
mod error;
-mod middleware;
mod mime;
mod routes;
diff --git a/src/ui/routes/mod.rs b/src/ui/routes/mod.rs
index 80dc1e5..dc94773 100644
--- a/src/ui/routes/mod.rs
+++ b/src/ui/routes/mod.rs
@@ -1,6 +1,6 @@
-use axum::{Router, middleware, routing::get};
+use axum::{Router, response::Redirect, routing::get};
-use crate::{app::App, ui::middleware::setup_required};
+use crate::app::App;
mod ch;
mod get;
@@ -21,7 +21,7 @@ pub fn router(app: &App) -> Router<App> {
.route("/login", get(login::get::handler))
.route("/ch/{channel}", get(ch::channel::get::handler))
.route("/invite/{invite}", get(invite::invite::get::handler))
- .route_layer(middleware::from_fn_with_state(app.clone(), setup_required)),
+ .route_layer(crate::setup::Required(app.clone()).with_fallback(Redirect::to("/setup"))),
]
.into_iter()
.fold(Router::default(), Router::merge)