summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOwen Jacobson <owen@grimoire.ca>2024-09-18 01:27:47 -0400
committerOwen Jacobson <owen@grimoire.ca>2024-09-18 12:17:46 -0400
commitcce6662d635bb2115f9f2a7bab92cc105166e761 (patch)
tree9d1edfea364a3b72cf40c78d67ce05e3e68c84df
parent921f38a73e5d58a5a6077477a8b52d2705798f55 (diff)
App methods now return errors that allow not-found cases to be distinguished.
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml1
-rw-r--r--src/app.rs4
-rw-r--r--src/channel/app.rs40
-rw-r--r--src/index/app.rs22
-rw-r--r--src/login/app.rs54
-rw-r--r--src/login/routes.rs24
-rw-r--r--src/repo/error.rs23
-rw-r--r--src/repo/login/extract.rs15
-rw-r--r--src/repo/mod.rs1
-rw-r--r--src/repo/token.rs4
11 files changed, 133 insertions, 56 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 448bef3..3d39e00 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -784,6 +784,7 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
+ "thiserror",
"tokio",
"tokio-stream",
"uuid",
diff --git a/Cargo.toml b/Cargo.toml
index 505378d..de941b5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,6 +21,7 @@ rust-embed = "8.5.0"
serde = { version = "1.0.209", features = ["derive"] }
serde_json = "1.0.128"
sqlx = { version = "0.8.1", features = ["chrono", "runtime-tokio", "sqlite"] }
+thiserror = "1.0.63"
tokio = { version = "1.40.0", features = ["rt", "macros", "rt-multi-thread"] }
tokio-stream = { version = "0.1.16", features = ["sync"] }
uuid = { version = "1.10.0", features = ["v4"] }
diff --git a/src/app.rs b/src/app.rs
index 7a43be9..1177c5e 100644
--- a/src/app.rs
+++ b/src/app.rs
@@ -1,7 +1,5 @@
use sqlx::sqlite::SqlitePool;
-use crate::error::BoxedError;
-
use crate::{
channel::app::{Broadcaster, Channels},
index::app::Index,
@@ -15,7 +13,7 @@ pub struct App {
}
impl App {
- pub async fn from(db: SqlitePool) -> Result<Self, BoxedError> {
+ pub async fn from(db: SqlitePool) -> Result<Self, sqlx::Error> {
let broadcaster = Broadcaster::from_database(&db).await?;
Ok(Self { db, broadcaster })
}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index 29d9c09..e72564d 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -16,6 +16,7 @@ use crate::{
error::BoxedError,
repo::{
channel::{self, Channel, Provider as _},
+ error::NotFound as _,
login::Login,
},
};
@@ -30,7 +31,7 @@ impl<'a> Channels<'a> {
Self { db, broadcaster }
}
- pub async fn create(&self, name: &str) -> Result<(), BoxedError> {
+ pub async fn create(&self, name: &str) -> Result<(), InternalError> {
let mut tx = self.db.begin().await?;
let channel = tx.channels().create(name).await?;
self.broadcaster.register_channel(&channel);
@@ -39,7 +40,7 @@ impl<'a> Channels<'a> {
Ok(())
}
- pub async fn all(&self) -> Result<Vec<Channel>, BoxedError> {
+ pub async fn all(&self) -> Result<Vec<Channel>, InternalError> {
let mut tx = self.db.begin().await?;
let channels = tx.channels().all().await?;
tx.commit().await?;
@@ -53,9 +54,13 @@ impl<'a> Channels<'a> {
channel: &channel::Id,
body: &str,
sent_at: &DateTime,
- ) -> Result<(), BoxedError> {
+ ) -> Result<(), EventsError> {
let mut tx = self.db.begin().await?;
- let channel = tx.channels().by_id(channel).await?;
+ let channel = tx
+ .channels()
+ .by_id(channel)
+ .await
+ .not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
let message = tx
.broadcast()
.create(login, &channel, body, sent_at)
@@ -70,11 +75,8 @@ impl<'a> Channels<'a> {
&self,
channel: &channel::Id,
resume_at: Option<&DateTime>,
- ) -> Result<impl Stream<Item = Result<broadcast::Message, BoxedError>> + 'static, BoxedError>
+ ) -> Result<impl Stream<Item = Result<broadcast::Message, BoxedError>> + 'static, EventsError>
{
- let mut tx = self.db.begin().await?;
- let channel = tx.channels().by_id(channel).await?;
-
fn skip_stale<E>(
resume_at: Option<&DateTime>,
) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<Result<bool, E>> {
@@ -86,6 +88,12 @@ impl<'a> Channels<'a> {
}))
}
}
+ let mut tx = self
+ .db
+ .begin()
+ .await
+ .not_found(|| EventsError::ChannelNotFound(channel.clone()))?;
+ let channel = tx.channels().by_id(channel).await?;
let live_messages = self
.broadcaster
@@ -102,6 +110,20 @@ impl<'a> Channels<'a> {
}
}
+#[derive(Debug, thiserror::Error)]
+pub enum InternalError {
+ #[error("database error: {0}")]
+ DatabaseError(#[from] sqlx::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum EventsError {
+ #[error("channel {0} not found")]
+ ChannelNotFound(channel::Id),
+ #[error("database error: {0}")]
+ DatabaseError(#[from] sqlx::Error),
+}
+
// Clones will share the same senders collection.
#[derive(Clone)]
pub struct Broadcaster {
@@ -112,7 +134,7 @@ pub struct Broadcaster {
}
impl Broadcaster {
- pub async fn from_database(db: &SqlitePool) -> Result<Self, BoxedError> {
+ pub async fn from_database(db: &SqlitePool) -> Result<Self, sqlx::Error> {
let mut tx = db.begin().await?;
let channels = tx.channels().all().await?;
tx.commit().await?;
diff --git a/src/index/app.rs b/src/index/app.rs
index a4ef57f..a3456c0 100644
--- a/src/index/app.rs
+++ b/src/index/app.rs
@@ -1,8 +1,8 @@
use sqlx::sqlite::SqlitePool;
-use crate::{
- error::BoxedError,
- repo::channel::{self, Channel, Provider as _},
+use crate::repo::{
+ channel::{self, Channel, Provider as _},
+ error::NotFound as _,
};
pub struct Index<'a> {
@@ -14,11 +14,23 @@ impl<'a> Index<'a> {
Self { db }
}
- pub async fn channel(&self, channel: &channel::Id) -> Result<Channel, BoxedError> {
+ pub async fn channel(&self, channel: &channel::Id) -> Result<Channel, Error> {
let mut tx = self.db.begin().await?;
- let channel = tx.channels().by_id(channel).await?;
+ let channel = tx
+ .channels()
+ .by_id(channel)
+ .await
+ .not_found(|| Error::ChannelNotFound(channel.clone()))?;
tx.commit().await?;
Ok(channel)
}
}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("channel {0} not found")]
+ ChannelNotFound(channel::Id),
+ #[error("database error: {0}")]
+ DatabaseError(#[from] sqlx::Error),
+}
diff --git a/src/login/app.rs b/src/login/app.rs
index aec072c..f0e0571 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -3,9 +3,9 @@ use sqlx::sqlite::SqlitePool;
use super::repo::auth::Provider as _;
use crate::{
clock::DateTime,
- error::BoxedError,
password::StoredHash,
repo::{
+ error::NotFound as _,
login::{Login, Provider as _},
token::Provider as _,
},
@@ -25,7 +25,7 @@ impl<'a> Logins<'a> {
name: &str,
password: &str,
login_at: DateTime,
- ) -> Result<Option<String>, BoxedError> {
+ ) -> Result<String, LoginError> {
let mut tx = self.db.begin().await?;
let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? {
@@ -41,39 +41,53 @@ impl<'a> Logins<'a> {
Some(tx.logins().create(name, &password_hash).await?)
};
- // If `login` is Some, then we have an identity and can issue a token.
- // If `login` is None, then neither creating a new login nor
- // authenticating an existing one succeeded, and we must reject the
- // login attempt.
- let token = if let Some(login) = login {
- Some(tx.tokens().issue(&login, login_at).await?)
- } else {
- None
- };
-
+ let login = login.ok_or(LoginError::Rejected)?;
+ let token = tx.tokens().issue(&login, login_at).await?;
tx.commit().await?;
Ok(token)
}
- pub async fn validate(
- &self,
- secret: &str,
- used_at: DateTime,
- ) -> Result<Option<Login>, BoxedError> {
+ pub async fn validate(&self, secret: &str, used_at: DateTime) -> Result<Login, ValidateError> {
let mut tx = self.db.begin().await?;
tx.tokens().expire(used_at).await?;
- let login = tx.tokens().validate(secret, used_at).await?;
+ let login = tx
+ .tokens()
+ .validate(secret, used_at)
+ .await
+ .not_found(|| ValidateError::InvalidToken)?;
tx.commit().await?;
Ok(login)
}
- pub async fn logout(&self, secret: &str) -> Result<(), BoxedError> {
+ pub async fn logout(&self, secret: &str) -> Result<(), ValidateError> {
let mut tx = self.db.begin().await?;
- tx.tokens().revoke(secret).await?;
+ tx.tokens()
+ .revoke(secret)
+ .await
+ .not_found(|| ValidateError::InvalidToken)?;
+
tx.commit().await?;
Ok(())
}
}
+
+#[derive(Debug, thiserror::Error)]
+pub enum LoginError {
+ #[error("invalid login")]
+ Rejected,
+ #[error("database error: {0}")]
+ DatabaseError(#[from] sqlx::Error),
+ #[error("password hash error: {0}")]
+ PasswordHashError(#[from] password_hash::Error),
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ValidateError {
+ #[error("invalid token")]
+ InvalidToken,
+ #[error("database error: {0}")]
+ DatabaseError(#[from] sqlx::Error),
+}
diff --git a/src/login/routes.rs b/src/login/routes.rs
index 816926e..3c58b10 100644
--- a/src/login/routes.rs
+++ b/src/login/routes.rs
@@ -8,7 +8,7 @@ use axum::{
use crate::{app::App, clock::RequestedAt, error::InternalError};
-use super::extract::IdentityToken;
+use super::{app::LoginError, extract::IdentityToken};
pub fn router() -> Router<App> {
Router::new()
@@ -28,30 +28,28 @@ async fn on_login(
identity: IdentityToken,
Form(form): Form<LoginRequest>,
) -> Result<impl IntoResponse, InternalError> {
- let token = app.logins().login(&form.name, &form.password, now).await?;
-
- let resp = if let Some(token) = token {
- let identity = identity.set(&token);
- (identity, LoginResponse::Successful)
- } else {
- (identity, LoginResponse::Rejected)
- };
-
- Ok(resp)
+ match app.logins().login(&form.name, &form.password, now).await {
+ Ok(token) => {
+ let identity = identity.set(&token);
+ Ok(LoginResponse::Successful(identity))
+ }
+ Err(LoginError::Rejected) => Ok(LoginResponse::Rejected),
+ Err(other) => Err(other.into()),
+ }
}
enum LoginResponse {
Rejected,
- Successful,
+ Successful(IdentityToken),
}
impl IntoResponse for LoginResponse {
fn into_response(self) -> Response {
match self {
+ Self::Successful(identity) => (identity, Redirect::to("/")).into_response(),
Self::Rejected => {
(StatusCode::UNAUTHORIZED, "invalid name or password").into_response()
}
- Self::Successful => Redirect::to("/").into_response(),
}
}
}
diff --git a/src/repo/error.rs b/src/repo/error.rs
new file mode 100644
index 0000000..a5961e2
--- /dev/null
+++ b/src/repo/error.rs
@@ -0,0 +1,23 @@
+pub trait NotFound {
+ type Ok;
+ fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E>
+ where
+ E: From<sqlx::Error>,
+ F: FnOnce() -> E;
+}
+
+impl<T> NotFound for Result<T, sqlx::Error> {
+ type Ok = T;
+
+ fn not_found<E, F>(self, map: F) -> Result<T, E>
+ where
+ E: From<sqlx::Error>,
+ F: FnOnce() -> E,
+ {
+ match self {
+ Err(sqlx::Error::RowNotFound) => Err(map()),
+ Err(other) => Err(other.into()),
+ Ok(value) => Ok(value),
+ }
+ }
+}
diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs
index a068bc0..a45a1cd 100644
--- a/src/repo/login/extract.rs
+++ b/src/repo/login/extract.rs
@@ -5,7 +5,12 @@ use axum::{
};
use super::Login;
-use crate::{app::App, clock::RequestedAt, error::InternalError, login::extract::IdentityToken};
+use crate::{
+ app::App,
+ clock::RequestedAt,
+ error::InternalError,
+ login::{app::ValidateError, extract::IdentityToken},
+};
#[async_trait::async_trait]
impl FromRequestParts<App> for Login {
@@ -22,9 +27,11 @@ impl FromRequestParts<App> for Login {
let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?;
let app = State::<App>::from_request_parts(parts, state).await?;
- let login = app.logins().validate(secret, used_at).await?;
-
- login.ok_or(LoginError::Unauthorized)
+ match app.logins().validate(secret, used_at).await {
+ Ok(login) => Ok(login),
+ Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized),
+ Err(other) => Err(other.into()),
+ }
}
}
diff --git a/src/repo/mod.rs b/src/repo/mod.rs
index d8995a3..f36f0da 100644
--- a/src/repo/mod.rs
+++ b/src/repo/mod.rs
@@ -1,4 +1,5 @@
pub mod channel;
+pub mod error;
pub mod login;
pub mod message;
pub mod token;
diff --git a/src/repo/token.rs b/src/repo/token.rs
index 01a982e..5674c92 100644
--- a/src/repo/token.rs
+++ b/src/repo/token.rs
@@ -88,7 +88,7 @@ impl<'c> Tokens<'c> {
&mut self,
secret: &str,
used_at: DateTime,
- ) -> Result<Option<Login>, sqlx::Error> {
+ ) -> Result<Login, sqlx::Error> {
// I would use `update … returning` to do this in one query, but
// sqlite3, as of this writing, does not allow an update's `returning`
// clause to reference columns from tables joined into the update. Two
@@ -117,7 +117,7 @@ impl<'c> Tokens<'c> {
"#,
secret,
)
- .fetch_optional(&mut *self.0)
+ .fetch_one(&mut *self.0)
.await?;
Ok(login)