summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/boot/app.rs34
-rw-r--r--src/channel/app.rs49
-rw-r--r--src/channel/mod.rs5
-rw-r--r--src/channel/name.rs30
-rw-r--r--src/channel/repo.rs212
-rw-r--r--src/channel/routes/post.rs3
-rw-r--r--src/channel/snapshot.rs4
-rw-r--r--src/db/mod.rs15
-rw-r--r--src/event/app.rs30
-rw-r--r--src/event/routes/get.rs4
-rw-r--r--src/invite/app.rs3
-rw-r--r--src/invite/mod.rs4
-rw-r--r--src/invite/repo.rs3
-rw-r--r--src/invite/routes/invite/post.rs3
-rw-r--r--src/lib.rs3
-rw-r--r--src/login/app.rs3
-rw-r--r--src/login/mod.rs4
-rw-r--r--src/login/name.rs28
-rw-r--r--src/login/password.rs2
-rw-r--r--src/login/repo.rs98
-rw-r--r--src/login/routes/login/post.rs3
-rw-r--r--src/login/routes/logout/test.rs1
-rw-r--r--src/login/snapshot.rs3
-rw-r--r--src/message/app.rs13
-rw-r--r--src/message/body.rs2
-rw-r--r--src/name.rs85
-rw-r--r--src/normalize/mod.rs36
-rw-r--r--src/normalize/string.rs (renamed from src/nfc.rs)57
-rw-r--r--src/setup/app.rs3
-rw-r--r--src/setup/routes/post.rs3
-rw-r--r--src/test/fixtures/channel.rs3
-rw-r--r--src/test/fixtures/login.rs3
-rw-r--r--src/token/app.rs37
-rw-r--r--src/token/repo/auth.rs68
-rw-r--r--src/token/repo/mod.rs2
-rw-r--r--src/token/repo/token.rs68
36 files changed, 628 insertions, 296 deletions
diff --git a/src/boot/app.rs b/src/boot/app.rs
index ef48b2f..1d88608 100644
--- a/src/boot/app.rs
+++ b/src/boot/app.rs
@@ -2,8 +2,11 @@ use sqlx::sqlite::SqlitePool;
use super::Snapshot;
use crate::{
- channel::repo::Provider as _, event::repo::Provider as _, login::repo::Provider as _,
+ channel::{self, repo::Provider as _},
+ event::repo::Provider as _,
+ login::{self, repo::Provider as _},
message::repo::Provider as _,
+ name,
};
pub struct Boot<'a> {
@@ -15,7 +18,7 @@ impl<'a> Boot<'a> {
Self { db }
}
- pub async fn snapshot(&self) -> Result<Snapshot, sqlx::Error> {
+ pub async fn snapshot(&self) -> Result<Snapshot, Error> {
let mut tx = self.db.begin().await?;
let resume_point = tx.sequence().current().await?;
@@ -48,3 +51,30 @@ impl<'a> Boot<'a> {
})
}
}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum Error {
+ Name(#[from] name::Error),
+ Database(#[from] sqlx::Error),
+}
+
+impl From<login::repo::LoadError> for Error {
+ fn from(error: login::repo::LoadError) -> Self {
+ use login::repo::LoadError;
+ match error {
+ LoadError::Name(error) => error.into(),
+ LoadError::Database(error) => error.into(),
+ }
+ }
+}
+
+impl From<channel::repo::LoadError> for Error {
+ fn from(error: channel::repo::LoadError) -> Self {
+ use channel::repo::LoadError;
+ match error {
+ LoadError::Name(error) => error.into(),
+ LoadError::Database(error) => error.into(),
+ }
+ }
+}
diff --git a/src/channel/app.rs b/src/channel/app.rs
index ea60943..b8ceeb0 100644
--- a/src/channel/app.rs
+++ b/src/channel/app.rs
@@ -2,12 +2,16 @@ use chrono::TimeDelta;
use itertools::Itertools;
use sqlx::sqlite::SqlitePool;
-use super::{repo::Provider as _, Channel, History, Id, Name};
+use super::{
+ repo::{LoadError, Provider as _},
+ Channel, History, Id,
+};
use crate::{
clock::DateTime,
db::{Duplicate as _, NotFound as _},
event::{repo::Provider as _, Broadcaster, Event, Sequence},
message::repo::Provider as _,
+ name::{self, Name},
};
pub struct Channels<'a> {
@@ -38,7 +42,7 @@ impl<'a> Channels<'a> {
// This function is careless with respect to time, and gets you the channel as
// it exists in the specific moment when you call it.
- pub async fn get(&self, channel: &Id) -> Result<Option<Channel>, sqlx::Error> {
+ pub async fn get(&self, channel: &Id) -> Result<Option<Channel>, Error> {
let mut tx = self.db.begin().await?;
let channel = tx.channels().by_id(channel).await.optional()?;
tx.commit().await?;
@@ -88,7 +92,7 @@ impl<'a> Channels<'a> {
Ok(())
}
- pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
+ pub async fn expire(&self, relative_to: &DateTime) -> Result<(), ExpireError> {
// Somewhat arbitrarily, expire after 90 days.
let expire_at = relative_to.to_owned() - TimeDelta::days(90);
@@ -137,6 +141,17 @@ pub enum CreateError {
DuplicateName(Name),
#[error(transparent)]
Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Name(#[from] name::Error),
+}
+
+impl From<LoadError> for CreateError {
+ fn from(error: LoadError) -> Self {
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
}
#[derive(Debug, thiserror::Error)]
@@ -147,4 +162,32 @@ pub enum Error {
Deleted(Id),
#[error(transparent)]
Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Name(#[from] name::Error),
+}
+
+impl From<LoadError> for Error {
+ fn from(error: LoadError) -> Self {
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ExpireError {
+ #[error(transparent)]
+ Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Name(#[from] name::Error),
+}
+
+impl From<LoadError> for ExpireError {
+ fn from(error: LoadError) -> Self {
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
}
diff --git a/src/channel/mod.rs b/src/channel/mod.rs
index fb13e92..eb8200b 100644
--- a/src/channel/mod.rs
+++ b/src/channel/mod.rs
@@ -2,11 +2,8 @@ pub mod app;
pub mod event;
mod history;
mod id;
-mod name;
pub mod repo;
mod routes;
mod snapshot;
-pub use self::{
- event::Event, history::History, id::Id, name::Name, routes::router, snapshot::Channel,
-};
+pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Channel};
diff --git a/src/channel/name.rs b/src/channel/name.rs
deleted file mode 100644
index fc82dec..0000000
--- a/src/channel/name.rs
+++ /dev/null
@@ -1,30 +0,0 @@
-use std::fmt;
-
-use crate::nfc;
-
-#[derive(
- Clone, Debug, Default, Eq, PartialEq, serde::Deserialize, serde::Serialize, sqlx::Type,
-)]
-#[serde(transparent)]
-#[sqlx(transparent)]
-pub struct Name(nfc::String);
-
-impl fmt::Display for Name {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- let Self(name) = self;
- name.fmt(f)
- }
-}
-
-impl From<String> for Name {
- fn from(name: String) -> Self {
- Self(name.into())
- }
-}
-
-impl From<Name> for String {
- fn from(name: Name) -> Self {
- let Name(name) = name;
- name.into()
- }
-}
diff --git a/src/channel/repo.rs b/src/channel/repo.rs
index 3353bfd..4baa95b 100644
--- a/src/channel/repo.rs
+++ b/src/channel/repo.rs
@@ -1,9 +1,12 @@
+use futures::stream::{StreamExt as _, TryStreamExt as _};
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use crate::{
- channel::{Channel, History, Id, Name},
+ channel::{Channel, History, Id},
clock::DateTime,
+ db::NotFound,
event::{Instant, ResumePoint, Sequence},
+ name::{self, Name},
};
pub trait Provider {
@@ -21,132 +24,160 @@ pub struct Channels<'t>(&'t mut SqliteConnection);
impl<'c> Channels<'c> {
pub async fn create(&mut self, name: &Name, created: &Instant) -> Result<History, sqlx::Error> {
let id = Id::generate();
- let channel = sqlx::query!(
+ let name = name.clone();
+ let display_name = name.display();
+ let canonical_name = name.canonical();
+ let created = *created;
+
+ sqlx::query!(
r#"
insert
- into channel (id, name, created_at, created_sequence)
- values ($1, $2, $3, $4)
- returning
- id as "id: Id",
- name as "name!: Name", -- known non-null as we just set it
- created_at as "created_at: DateTime",
- created_sequence as "created_sequence: Sequence"
+ into channel (id, created_at, created_sequence)
+ values ($1, $2, $3)
"#,
id,
- name,
created.at,
created.sequence,
)
- .map(|row| History {
+ .execute(&mut *self.0)
+ .await?;
+
+ sqlx::query!(
+ r#"
+ insert into channel_name (id, display_name, canonical_name)
+ values ($1, $2, $3)
+ "#,
+ id,
+ display_name,
+ canonical_name,
+ )
+ .execute(&mut *self.0)
+ .await?;
+
+ let channel = History {
channel: Channel {
- id: row.id,
- name: row.name,
+ id,
+ name: name.clone(),
deleted_at: None,
},
- created: Instant::new(row.created_at, row.created_sequence),
+ created,
deleted: None,
- })
- .fetch_one(&mut *self.0)
- .await?;
+ };
Ok(channel)
}
- pub async fn by_id(&mut self, channel: &Id) -> Result<History, sqlx::Error> {
+ pub async fn by_id(&mut self, channel: &Id) -> Result<History, LoadError> {
let channel = sqlx::query!(
r#"
select
id as "id: Id",
- channel.name as "name: Name",
+ name.display_name as "display_name?: String",
+ name.canonical_name as "canonical_name?: String",
channel.created_at as "created_at: DateTime",
channel.created_sequence as "created_sequence: Sequence",
deleted.deleted_at as "deleted_at?: DateTime",
deleted.deleted_sequence as "deleted_sequence?: Sequence"
from channel
+ left join channel_name as name
+ using (id)
left join channel_deleted as deleted
using (id)
where id = $1
"#,
channel,
)
- .map(|row| History {
- channel: Channel {
- id: row.id,
- name: row.name.unwrap_or_default(),
- deleted_at: row.deleted_at,
- },
- created: Instant::new(row.created_at, row.created_sequence),
- deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ .map(|row| {
+ Ok::<_, name::Error>(History {
+ channel: Channel {
+ id: row.id,
+ name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(),
+ deleted_at: row.deleted_at,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ })
})
.fetch_one(&mut *self.0)
- .await?;
+ .await??;
Ok(channel)
}
- pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> {
+ pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> {
let channels = sqlx::query!(
r#"
select
id as "id: Id",
- channel.name as "name: Name",
+ name.display_name as "display_name: String",
+ name.canonical_name as "canonical_name: String",
channel.created_at as "created_at: DateTime",
channel.created_sequence as "created_sequence: Sequence",
- deleted.deleted_at as "deleted_at: DateTime",
- deleted.deleted_sequence as "deleted_sequence: Sequence"
+ deleted.deleted_at as "deleted_at?: DateTime",
+ deleted.deleted_sequence as "deleted_sequence?: Sequence"
from channel
+ left join channel_name as name
+ using (id)
left join channel_deleted as deleted
using (id)
where coalesce(channel.created_sequence <= $1, true)
- order by channel.name
+ order by name.canonical_name
"#,
resume_at,
)
- .map(|row| History {
- channel: Channel {
- id: row.id,
- name: row.name.unwrap_or_default(),
- deleted_at: row.deleted_at,
- },
- created: Instant::new(row.created_at, row.created_sequence),
- deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ .map(|row| {
+ Ok::<_, name::Error>(History {
+ channel: Channel {
+ id: row.id,
+ name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(),
+ deleted_at: row.deleted_at,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ })
})
- .fetch_all(&mut *self.0)
+ .fetch(&mut *self.0)
+ .map(|res| Ok::<_, LoadError>(res??))
+ .try_collect()
.await?;
Ok(channels)
}
- pub async fn replay(
- &mut self,
- resume_at: Option<Sequence>,
- ) -> Result<Vec<History>, sqlx::Error> {
+ pub async fn replay(&mut self, resume_at: Option<Sequence>) -> Result<Vec<History>, LoadError> {
let channels = sqlx::query!(
r#"
select
id as "id: Id",
- channel.name as "name: Name",
+ name.display_name as "display_name: String",
+ name.canonical_name as "canonical_name: String",
channel.created_at as "created_at: DateTime",
channel.created_sequence as "created_sequence: Sequence",
- deleted.deleted_at as "deleted_at: DateTime",
- deleted.deleted_sequence as "deleted_sequence: Sequence"
+ deleted.deleted_at as "deleted_at?: DateTime",
+ deleted.deleted_sequence as "deleted_sequence?: Sequence"
from channel
+ left join channel_name as name
+ using (id)
left join channel_deleted as deleted
using (id)
where coalesce(channel.created_sequence > $1, true)
"#,
resume_at,
)
- .map(|row| History {
- channel: Channel {
- id: row.id,
- name: row.name.unwrap_or_default(),
- deleted_at: row.deleted_at,
- },
- created: Instant::new(row.created_at, row.created_sequence),
- deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ .map(|row| {
+ Ok::<_, name::Error>(History {
+ channel: Channel {
+ id: row.id,
+ name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(),
+ deleted_at: row.deleted_at,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ })
})
- .fetch_all(&mut *self.0)
+ .fetch(&mut *self.0)
+ .map(|res| Ok::<_, LoadError>(res??))
+ .try_collect()
.await?;
Ok(channels)
@@ -156,19 +187,18 @@ impl<'c> Channels<'c> {
&mut self,
channel: &History,
deleted: &Instant,
- ) -> Result<History, sqlx::Error> {
+ ) -> Result<History, LoadError> {
let id = channel.id();
- sqlx::query_scalar!(
+ sqlx::query!(
r#"
insert into channel_deleted (id, deleted_at, deleted_sequence)
values ($1, $2, $3)
- returning 1 as "deleted: bool"
"#,
id,
deleted.at,
deleted.sequence,
)
- .fetch_one(&mut *self.0)
+ .execute(&mut *self.0)
.await?;
// Small social responsibility hack here: when a channel is deleted, its name is
@@ -179,16 +209,14 @@ impl<'c> Channels<'c> {
// This also avoids the need for a separate name reservation table to ensure
// that live channels have unique names, since the `channel` table's name field
// is unique over non-null values.
- sqlx::query_scalar!(
+ sqlx::query!(
r#"
- update channel
- set name = null
+ delete from channel_name
where id = $1
- returning 1 as "updated: bool"
"#,
id,
)
- .fetch_one(&mut *self.0)
+ .execute(&mut *self.0)
.await?;
let channel = self.by_id(id).await?;
@@ -230,38 +258,66 @@ impl<'c> Channels<'c> {
Ok(())
}
- pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<History>, sqlx::Error> {
+ pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<History>, LoadError> {
let channels = sqlx::query!(
r#"
select
channel.id as "id: Id",
- channel.name as "name: Name",
+ name.display_name as "display_name: String",
+ name.canonical_name as "canonical_name: String",
channel.created_at as "created_at: DateTime",
channel.created_sequence as "created_sequence: Sequence",
deleted.deleted_at as "deleted_at?: DateTime",
deleted.deleted_sequence as "deleted_sequence?: Sequence"
from channel
+ left join channel_name as name
+ using (id)
left join channel_deleted as deleted
using (id)
left join message
+ on channel.id = message.channel
where channel.created_at < $1
and message.id is null
and deleted.id is null
"#,
expired_at,
)
- .map(|row| History {
- channel: Channel {
- id: row.id,
- name: row.name.unwrap_or_default(),
- deleted_at: row.deleted_at,
- },
- created: Instant::new(row.created_at, row.created_sequence),
- deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ .map(|row| {
+ Ok::<_, name::Error>(History {
+ channel: Channel {
+ id: row.id,
+ name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(),
+ deleted_at: row.deleted_at,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ deleted: Instant::optional(row.deleted_at, row.deleted_sequence),
+ })
})
- .fetch_all(&mut *self.0)
+ .fetch(&mut *self.0)
+ .map(|res| Ok::<_, LoadError>(res??))
+ .try_collect()
.await?;
Ok(channels)
}
}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum LoadError {
+ Database(#[from] sqlx::Error),
+ Name(#[from] name::Error),
+}
+
+impl<T> NotFound for Result<T, LoadError> {
+ type Ok = T;
+ type Error = LoadError;
+
+ fn optional(self) -> Result<Option<T>, LoadError> {
+ match self {
+ Ok(value) => Ok(Some(value)),
+ Err(LoadError::Database(sqlx::Error::RowNotFound)) => Ok(None),
+ Err(other) => Err(other),
+ }
+ }
+}
diff --git a/src/channel/routes/post.rs b/src/channel/routes/post.rs
index d354f79..9781dd7 100644
--- a/src/channel/routes/post.rs
+++ b/src/channel/routes/post.rs
@@ -6,10 +6,11 @@ use axum::{
use crate::{
app::App,
- channel::{app, Channel, Name},
+ channel::{app, Channel},
clock::RequestedAt,
error::Internal,
login::Login,
+ name::Name,
};
pub async fn handler(
diff --git a/src/channel/snapshot.rs b/src/channel/snapshot.rs
index dc2894d..129c0d6 100644
--- a/src/channel/snapshot.rs
+++ b/src/channel/snapshot.rs
@@ -1,8 +1,8 @@
use super::{
event::{Created, Event},
- Id, Name,
+ Id,
};
-use crate::clock::DateTime;
+use crate::{clock::DateTime, name::Name};
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
pub struct Channel {
diff --git a/src/db/mod.rs b/src/db/mod.rs
index 6005813..e0522d4 100644
--- a/src/db/mod.rs
+++ b/src/db/mod.rs
@@ -130,14 +130,17 @@ pub enum Error {
Rejected(String, String),
}
-pub trait NotFound {
+pub trait NotFound: Sized {
type Ok;
type Error;
fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E>
where
E: From<Self::Error>,
- F: FnOnce() -> E;
+ F: FnOnce() -> E,
+ {
+ self.optional()?.ok_or_else(map)
+ }
fn optional(self) -> Result<Option<Self::Ok>, Self::Error>;
}
@@ -153,14 +156,6 @@ impl<T> NotFound for Result<T, sqlx::Error> {
Err(other) => Err(other),
}
}
-
- fn not_found<E, F>(self, map: F) -> Result<T, E>
- where
- E: From<sqlx::Error>,
- F: FnOnce() -> E,
- {
- self.optional()?.ok_or_else(map)
- }
}
pub trait Duplicate {
diff --git a/src/event/app.rs b/src/event/app.rs
index 951ce25..c754388 100644
--- a/src/event/app.rs
+++ b/src/event/app.rs
@@ -11,6 +11,7 @@ use crate::{
channel::{self, repo::Provider as _},
login::{self, repo::Provider as _},
message::{self, repo::Provider as _},
+ name,
};
pub struct Events<'a> {
@@ -26,7 +27,7 @@ impl<'a> Events<'a> {
pub async fn subscribe(
&self,
resume_at: impl Into<ResumePoint>,
- ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, sqlx::Error> {
+ ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, Error> {
let resume_at = resume_at.into();
// Subscribe before retrieving, to catch messages broadcast while we're
// querying the DB. We'll prune out duplicates later.
@@ -81,3 +82,30 @@ impl<'a> Events<'a> {
move |event| future::ready(filter(event))
}
}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum Error {
+ Database(#[from] sqlx::Error),
+ Name(#[from] name::Error),
+}
+
+impl From<login::repo::LoadError> for Error {
+ fn from(error: login::repo::LoadError) -> Self {
+ use login::repo::LoadError;
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
+}
+
+impl From<channel::repo::LoadError> for Error {
+ fn from(error: channel::repo::LoadError) -> Self {
+ use channel::repo::LoadError;
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
+}
diff --git a/src/event/routes/get.rs b/src/event/routes/get.rs
index 357845a..22e8762 100644
--- a/src/event/routes/get.rs
+++ b/src/event/routes/get.rs
@@ -12,7 +12,7 @@ use futures::stream::{Stream, StreamExt as _};
use crate::{
app::App,
error::{Internal, Unauthorized},
- event::{extract::LastEventId, Event, ResumePoint, Sequence, Sequenced as _},
+ event::{app, extract::LastEventId, Event, ResumePoint, Sequence, Sequenced as _},
token::{app::ValidateError, extract::Identity},
};
@@ -69,7 +69,7 @@ impl TryFrom<Event> for sse::Event {
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum Error {
- Database(#[from] sqlx::Error),
+ Subscribe(#[from] app::Error),
Validate(#[from] ValidateError),
}
diff --git a/src/invite/app.rs b/src/invite/app.rs
index 285a819..64ba753 100644
--- a/src/invite/app.rs
+++ b/src/invite/app.rs
@@ -6,7 +6,8 @@ use crate::{
clock::DateTime,
db::{Duplicate as _, NotFound as _},
event::repo::Provider as _,
- login::{repo::Provider as _, Login, Name, Password},
+ login::{repo::Provider as _, Login, Password},
+ name::Name,
token::{repo::Provider as _, Secret},
};
diff --git a/src/invite/mod.rs b/src/invite/mod.rs
index abf1c3a..d59fb9c 100644
--- a/src/invite/mod.rs
+++ b/src/invite/mod.rs
@@ -3,7 +3,7 @@ mod id;
mod repo;
mod routes;
-use crate::{clock::DateTime, login};
+use crate::{clock::DateTime, login, normalize::nfc};
pub use self::{id::Id, routes::router};
@@ -17,6 +17,6 @@ pub struct Invite {
#[derive(serde::Serialize)]
pub struct Summary {
pub id: Id,
- pub issuer: String,
+ pub issuer: nfc::String,
pub issued_at: DateTime,
}
diff --git a/src/invite/repo.rs b/src/invite/repo.rs
index 643f5b7..02f4e42 100644
--- a/src/invite/repo.rs
+++ b/src/invite/repo.rs
@@ -4,6 +4,7 @@ use super::{Id, Invite, Summary};
use crate::{
clock::DateTime,
login::{self, Login},
+ normalize::nfc,
};
pub trait Provider {
@@ -70,7 +71,7 @@ impl<'c> Invites<'c> {
select
invite.id as "invite_id: Id",
issuer.id as "issuer_id: login::Id",
- issuer.name as "issuer_name",
+ issuer.display_name as "issuer_name: nfc::String",
invite.issued_at as "invite_issued_at: DateTime"
from invite
join login as issuer on (invite.issuer = issuer.id)
diff --git a/src/invite/routes/invite/post.rs b/src/invite/routes/invite/post.rs
index 8160465..a41207a 100644
--- a/src/invite/routes/invite/post.rs
+++ b/src/invite/routes/invite/post.rs
@@ -9,7 +9,8 @@ use crate::{
clock::RequestedAt,
error::{Internal, NotFound},
invite::app,
- login::{Login, Name, Password},
+ login::{Login, Password},
+ name::Name,
token::extract::IdentityToken,
};
diff --git a/src/lib.rs b/src/lib.rs
index 4d0d9b9..84b8dfc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -16,7 +16,8 @@ mod id;
mod invite;
mod login;
mod message;
-mod nfc;
+mod name;
+mod normalize;
mod setup;
#[cfg(test)]
mod test;
diff --git a/src/login/app.rs b/src/login/app.rs
index ebc1c00..37f1249 100644
--- a/src/login/app.rs
+++ b/src/login/app.rs
@@ -1,9 +1,10 @@
use sqlx::sqlite::SqlitePool;
-use super::{repo::Provider as _, Login, Name, Password};
+use super::{repo::Provider as _, Login, Password};
use crate::{
clock::DateTime,
event::{repo::Provider as _, Broadcaster, Event},
+ name::Name,
};
pub struct Logins<'a> {
diff --git a/src/login/mod.rs b/src/login/mod.rs
index 71d5bfc..98cc3d7 100644
--- a/src/login/mod.rs
+++ b/src/login/mod.rs
@@ -4,13 +4,11 @@ pub mod event;
pub mod extract;
mod history;
mod id;
-mod name;
pub mod password;
pub mod repo;
mod routes;
mod snapshot;
pub use self::{
- event::Event, history::History, id::Id, name::Name, password::Password, routes::router,
- snapshot::Login,
+ event::Event, history::History, id::Id, password::Password, routes::router, snapshot::Login,
};
diff --git a/src/login/name.rs b/src/login/name.rs
deleted file mode 100644
index d882ff9..0000000
--- a/src/login/name.rs
+++ /dev/null
@@ -1,28 +0,0 @@
-use std::fmt;
-
-use crate::nfc;
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize, sqlx::Type)]
-#[serde(transparent)]
-#[sqlx(transparent)]
-pub struct Name(nfc::String);
-
-impl fmt::Display for Name {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- let Self(name) = self;
- name.fmt(f)
- }
-}
-
-impl From<String> for Name {
- fn from(name: String) -> Self {
- Self(name.into())
- }
-}
-
-impl From<Name> for String {
- fn from(name: Name) -> Self {
- let Name(name) = name;
- name.into()
- }
-}
diff --git a/src/login/password.rs b/src/login/password.rs
index f9ecf37..c27c950 100644
--- a/src/login/password.rs
+++ b/src/login/password.rs
@@ -4,7 +4,7 @@ use argon2::Argon2;
use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use rand_core::OsRng;
-use crate::nfc;
+use crate::normalize::nfc;
#[derive(sqlx::Type)]
#[sqlx(transparent)]
diff --git a/src/login/repo.rs b/src/login/repo.rs
index 204329f..6021f26 100644
--- a/src/login/repo.rs
+++ b/src/login/repo.rs
@@ -1,9 +1,11 @@
+use futures::stream::{StreamExt as _, TryStreamExt as _};
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use crate::{
clock::DateTime,
event::{Instant, ResumePoint, Sequence},
- login::{password::StoredHash, History, Id, Login, Name},
+ login::{password::StoredHash, History, Id, Login},
+ name::{self, Name},
};
pub trait Provider {
@@ -26,43 +28,43 @@ impl<'c> Logins<'c> {
created: &Instant,
) -> Result<History, sqlx::Error> {
let id = Id::generate();
+ let display_name = name.display();
+ let canonical_name = name.canonical();
- let login = sqlx::query!(
+ sqlx::query!(
r#"
insert
- into login (id, name, password_hash, created_sequence, created_at)
- values ($1, $2, $3, $4, $5)
- returning
- id as "id: Id",
- name as "name: Name",
- created_sequence as "created_sequence: Sequence",
- created_at as "created_at: DateTime"
+ into login (id, display_name, canonical_name, password_hash, created_sequence, created_at)
+ values ($1, $2, $3, $4, $5, $6)
"#,
id,
- name,
+ display_name,
+ canonical_name,
password_hash,
created.sequence,
created.at,
)
- .map(|row| History {
+ .execute(&mut *self.0)
+ .await?;
+
+ let login = History {
+ created: *created,
login: Login {
- id: row.id,
- name: row.name,
+ id,
+ name: name.clone(),
},
- created: Instant::new(row.created_at, row.created_sequence),
- })
- .fetch_one(&mut *self.0)
- .await?;
+ };
Ok(login)
}
- pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> {
- let channels = sqlx::query!(
+ pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> {
+ let logins = sqlx::query!(
r#"
select
id as "id: Id",
- name as "name: Name",
+ display_name as "display_name: String",
+ canonical_name as "canonical_name: String",
created_sequence as "created_sequence: Sequence",
created_at as "created_at: DateTime"
from login
@@ -71,24 +73,29 @@ impl<'c> Logins<'c> {
"#,
resume_at,
)
- .map(|row| History {
- login: Login {
- id: row.id,
- name: row.name,
- },
- created: Instant::new(row.created_at, row.created_sequence),
+ .map(|row| {
+ Ok::<_, LoadError>(History {
+ login: Login {
+ id: row.id,
+ name: Name::new(row.display_name, row.canonical_name)?,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ })
})
- .fetch_all(&mut *self.0)
+ .fetch(&mut *self.0)
+ .map(|res| res?)
+ .try_collect()
.await?;
- Ok(channels)
+ Ok(logins)
}
- pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> {
- let messages = sqlx::query!(
+ pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> {
+ let logins = sqlx::query!(
r#"
select
id as "id: Id",
- name as "name: Name",
+ display_name as "display_name: String",
+ canonical_name as "canonical_name: String",
created_sequence as "created_sequence: Sequence",
created_at as "created_at: DateTime"
from login
@@ -96,22 +103,27 @@ impl<'c> Logins<'c> {
"#,
resume_at,
)
- .map(|row| History {
- login: Login {
- id: row.id,
- name: row.name,
- },
- created: Instant::new(row.created_at, row.created_sequence),
+ .map(|row| {
+ Ok::<_, name::Error>(History {
+ login: Login {
+ id: row.id,
+ name: Name::new(row.display_name, row.canonical_name)?,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ })
})
- .fetch_all(&mut *self.0)
+ .fetch(&mut *self.0)
+ .map(|res| Ok::<_, LoadError>(res??))
+ .try_collect()
.await?;
- Ok(messages)
+ Ok(logins)
}
}
-impl<'t> From<&'t mut SqliteConnection> for Logins<'t> {
- fn from(tx: &'t mut SqliteConnection) -> Self {
- Self(tx)
- }
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum LoadError {
+ Database(#[from] sqlx::Error),
+ Name(#[from] name::Error),
}
diff --git a/src/login/routes/login/post.rs b/src/login/routes/login/post.rs
index 7a685e2..20430db 100644
--- a/src/login/routes/login/post.rs
+++ b/src/login/routes/login/post.rs
@@ -8,7 +8,8 @@ use crate::{
app::App,
clock::RequestedAt,
error::Internal,
- login::{Login, Name, Password},
+ login::{Login, Password},
+ name::Name,
token::{app, extract::IdentityToken},
};
diff --git a/src/login/routes/logout/test.rs b/src/login/routes/logout/test.rs
index 0e70e4c..91837fe 100644
--- a/src/login/routes/logout/test.rs
+++ b/src/login/routes/logout/test.rs
@@ -33,7 +33,6 @@ async fn successful() {
assert_eq!(StatusCode::NO_CONTENT, response_status);
// Verify the semantics
-
let error = app
.tokens()
.validate(&secret, &now)
diff --git a/src/login/snapshot.rs b/src/login/snapshot.rs
index 85800e4..e1eb96c 100644
--- a/src/login/snapshot.rs
+++ b/src/login/snapshot.rs
@@ -1,7 +1,8 @@
use super::{
event::{Created, Event},
- Id, Name,
+ Id,
};
+use crate::name::Name;
// This also implements FromRequestParts (see `./extract.rs`). As a result, it
// can be used as an extractor for endpoints that want to require login, or for
diff --git a/src/message/app.rs b/src/message/app.rs
index af87553..852b958 100644
--- a/src/message/app.rs
+++ b/src/message/app.rs
@@ -9,6 +9,7 @@ use crate::{
db::NotFound as _,
event::{repo::Provider as _, Broadcaster, Event, Sequence},
login::Login,
+ name,
};
pub struct Messages<'a> {
@@ -119,6 +120,18 @@ pub enum SendError {
ChannelNotFound(channel::Id),
#[error(transparent)]
Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Name(#[from] name::Error),
+}
+
+impl From<channel::repo::LoadError> for SendError {
+ fn from(error: channel::repo::LoadError) -> Self {
+ use channel::repo::LoadError;
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
}
#[derive(Debug, thiserror::Error)]
diff --git a/src/message/body.rs b/src/message/body.rs
index a415f85..6dd224c 100644
--- a/src/message/body.rs
+++ b/src/message/body.rs
@@ -1,6 +1,6 @@
use std::fmt;
-use crate::nfc;
+use crate::normalize::nfc;
#[derive(
Clone, Debug, Default, Eq, PartialEq, serde::Deserialize, serde::Serialize, sqlx::Type,
diff --git a/src/name.rs b/src/name.rs
new file mode 100644
index 0000000..9187d33
--- /dev/null
+++ b/src/name.rs
@@ -0,0 +1,85 @@
+use std::fmt;
+
+use crate::normalize::{ident, nfc};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize, sqlx::Type)]
+#[serde(from = "String", into = "String")]
+pub struct Name {
+ display: nfc::String,
+ canonical: ident::String,
+}
+
+impl Name {
+ pub fn new<D, C>(display: D, canonical: C) -> Result<Self, Error>
+ where
+ D: AsRef<str>,
+ C: AsRef<str>,
+ {
+ let name = Self::from(display);
+
+ if name.canonical.as_str() == canonical.as_ref() {
+ Ok(name)
+ } else {
+ Err(Error::CanonicalMismatch(
+ canonical.as_ref().into(),
+ name.canonical,
+ name.display,
+ ))
+ }
+ }
+
+ pub fn optional<D, C>(display: Option<D>, canonical: Option<C>) -> Result<Option<Self>, Error>
+ where
+ D: AsRef<str>,
+ C: AsRef<str>,
+ {
+ display
+ .zip(canonical)
+ .map(|(display, canonical)| Self::new(display, canonical))
+ .transpose()
+ }
+
+ pub fn display(&self) -> &nfc::String {
+ &self.display
+ }
+
+ pub fn canonical(&self) -> &ident::String {
+ &self.canonical
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("stored canonical form {0:#?} does not match computed canonical form {:#?} for name {:#?}", .1.as_str(), .2.as_str())]
+ CanonicalMismatch(String, ident::String, nfc::String),
+}
+
+impl Default for Name {
+ fn default() -> Self {
+ Self::from(String::default())
+ }
+}
+
+impl fmt::Display for Name {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.display.fmt(f)
+ }
+}
+
+impl<S> From<S> for Name
+where
+ S: AsRef<str>,
+{
+ fn from(name: S) -> Self {
+ let display = nfc::String::from(&name);
+ let canonical = ident::String::from(&name);
+
+ Self { display, canonical }
+ }
+}
+
+impl From<Name> for String {
+ fn from(name: Name) -> Self {
+ name.display.into()
+ }
+}
diff --git a/src/normalize/mod.rs b/src/normalize/mod.rs
new file mode 100644
index 0000000..6294201
--- /dev/null
+++ b/src/normalize/mod.rs
@@ -0,0 +1,36 @@
+mod string;
+
+pub mod nfc {
+ use std::string::String as StdString;
+
+ use unicode_normalization::UnicodeNormalization as _;
+
+ pub type String = super::string::String<Nfc>;
+
+ #[derive(Clone, Debug, Default, Eq, PartialEq)]
+ pub struct Nfc;
+
+ impl super::string::Normalize for Nfc {
+ fn normalize(&self, value: &str) -> StdString {
+ value.nfc().collect()
+ }
+ }
+}
+
+pub mod ident {
+ use std::string::String as StdString;
+
+ use unicode_casefold::UnicodeCaseFold as _;
+ use unicode_normalization::UnicodeNormalization as _;
+
+ pub type String = super::string::String<Ident>;
+
+ #[derive(Clone, Debug, Default, Eq, PartialEq)]
+ pub struct Ident;
+
+ impl super::string::Normalize for Ident {
+ fn normalize(&self, value: &str) -> StdString {
+ value.case_fold().nfkc().collect()
+ }
+ }
+}
diff --git a/src/nfc.rs b/src/normalize/string.rs
index 70e936c..a0d178c 100644
--- a/src/nfc.rs
+++ b/src/normalize/string.rs
@@ -4,39 +4,48 @@ use sqlx::{
encode::{Encode, IsNull},
Database, Decode, Type,
};
-use unicode_normalization::UnicodeNormalization as _;
-#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
-#[serde(from = "StdString", into = "StdString")]
-pub struct String(StdString);
+pub trait Normalize: Clone + Default {
+ fn normalize(&self, value: &str) -> StdString;
+}
+
+#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
+#[serde(into = "StdString", from = "StdString")]
+#[serde(bound = "N: Normalize")]
+pub struct String<N>(StdString, N);
-impl fmt::Display for String {
+impl<N> fmt::Display for String<N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- let Self(value) = self;
+ let Self(value, _) = self;
value.fmt(f)
}
}
-impl From<StdString> for String {
- fn from(value: StdString) -> Self {
- let value = value.nfc().collect();
+impl<S, N> From<S> for String<N>
+where
+ S: AsRef<str>,
+ N: Normalize,
+{
+ fn from(value: S) -> Self {
+ let normalizer = N::default();
+ let value = normalizer.normalize(value.as_ref());
- Self(value)
+ Self(value, normalizer)
}
}
-impl From<String> for StdString {
- fn from(value: String) -> Self {
- let String(value) = value;
+impl<N> From<String<N>> for StdString {
+ fn from(value: String<N>) -> Self {
+ let String(value, _) = value;
value
}
}
-impl std::ops::Deref for String {
+impl<N> std::ops::Deref for String<N> {
type Target = StdString;
fn deref(&self) -> &Self::Target {
- let Self(value) = self;
+ let Self(value, _) = self;
value
}
}
@@ -44,7 +53,7 @@ impl std::ops::Deref for String {
// Type is manually implemented so that we can implement Decode to do
// normalization on read. Implementation is otherwise based on
// `#[derive(sqlx::Type)]` with the `#[sqlx(transparent)]` attribute.
-impl<DB> Type<DB> for String
+impl<DB, N> Type<DB> for String<N>
where
DB: Database,
StdString: Type<DB>,
@@ -58,19 +67,19 @@ where
}
}
-impl<'r, DB> Decode<'r, DB> for String
+impl<'r, DB, N> Decode<'r, DB> for String<N>
where
DB: Database,
StdString: Decode<'r, DB>,
+ N: Normalize,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
let value = StdString::decode(value)?;
- let value = value.nfc().collect();
- Ok(Self(value))
+ Ok(Self::from(value))
}
}
-impl<'q, DB> Encode<'q, DB> for String
+impl<'q, DB, N> Encode<'q, DB> for String<N>
where
DB: Database,
StdString: Encode<'q, DB>,
@@ -79,7 +88,7 @@ where
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, sqlx::error::BoxDynError> {
- let Self(value) = self;
+ let Self(value, _) = self;
value.encode_by_ref(buf)
}
@@ -87,17 +96,17 @@ where
self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, sqlx::error::BoxDynError> {
- let Self(value) = self;
+ let Self(value, _) = self;
value.encode(buf)
}
fn produces(&self) -> Option<<DB as Database>::TypeInfo> {
- let Self(value) = self;
+ let Self(value, _) = self;
value.produces()
}
fn size_hint(&self) -> usize {
- let Self(value) = self;
+ let Self(value, _) = self;
value.size_hint()
}
}
diff --git a/src/setup/app.rs b/src/setup/app.rs
index 9fbcf6d..030b5f6 100644
--- a/src/setup/app.rs
+++ b/src/setup/app.rs
@@ -4,7 +4,8 @@ use super::repo::Provider as _;
use crate::{
clock::DateTime,
event::{repo::Provider as _, Broadcaster, Event},
- login::{repo::Provider as _, Login, Name, Password},
+ login::{repo::Provider as _, Login, Password},
+ name::Name,
token::{repo::Provider as _, Secret},
};
diff --git a/src/setup/routes/post.rs b/src/setup/routes/post.rs
index 6a3fa11..fb2280a 100644
--- a/src/setup/routes/post.rs
+++ b/src/setup/routes/post.rs
@@ -8,7 +8,8 @@ use crate::{
app::App,
clock::RequestedAt,
error::Internal,
- login::{Login, Name, Password},
+ login::{Login, Password},
+ name::Name,
setup::app,
token::extract::IdentityToken,
};
diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs
index 024ac1b..3831c82 100644
--- a/src/test/fixtures/channel.rs
+++ b/src/test/fixtures/channel.rs
@@ -8,9 +8,10 @@ use rand;
use crate::{
app::App,
- channel::{self, Channel, Name},
+ channel::{self, Channel},
clock::RequestedAt,
event::Event,
+ name::Name,
};
pub async fn create(app: &App, created_at: &RequestedAt) -> Channel {
diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs
index 0a42320..714b936 100644
--- a/src/test/fixtures/login.rs
+++ b/src/test/fixtures/login.rs
@@ -4,7 +4,8 @@ use uuid::Uuid;
use crate::{
app::App,
clock::RequestedAt,
- login::{self, Login, Name, Password},
+ login::{self, Login, Password},
+ name::Name,
};
pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Login, Password) {
diff --git a/src/token/app.rs b/src/token/app.rs
index d4dd1a0..c19d6a0 100644
--- a/src/token/app.rs
+++ b/src/token/app.rs
@@ -7,12 +7,14 @@ use futures::{
use sqlx::sqlite::SqlitePool;
use super::{
- repo::auth::Provider as _, repo::Provider as _, Broadcaster, Event as TokenEvent, Id, Secret,
+ repo::{self, auth::Provider as _, Provider as _},
+ Broadcaster, Event as TokenEvent, Id, Secret,
};
use crate::{
clock::DateTime,
db::NotFound as _,
- login::{Login, Name, Password},
+ login::{Login, Password},
+ name::{self, Name},
};
pub struct Tokens<'a> {
@@ -65,14 +67,16 @@ impl<'a> Tokens<'a> {
used_at: &DateTime,
) -> Result<(Id, Login), ValidateError> {
let mut tx = self.db.begin().await?;
- let login = tx
+ let (token, login) = tx
.tokens()
.validate(secret, used_at)
.await
.not_found(|| ValidateError::InvalidToken)?;
tx.commit().await?;
- Ok(login)
+ let login = login.as_snapshot().ok_or(ValidateError::LoginDeleted)?;
+
+ Ok((token, login))
}
pub async fn limit_stream<E>(
@@ -162,15 +166,40 @@ pub enum LoginError {
#[error(transparent)]
Database(#[from] sqlx::Error),
#[error(transparent)]
+ Name(#[from] name::Error),
+ #[error(transparent)]
PasswordHash(#[from] password_hash::Error),
}
+impl From<repo::auth::LoadError> for LoginError {
+ fn from(error: repo::auth::LoadError) -> Self {
+ use repo::auth::LoadError;
+ match error {
+ LoadError::Database(error) => error.into(),
+ LoadError::Name(error) => error.into(),
+ }
+ }
+}
+
#[derive(Debug, thiserror::Error)]
pub enum ValidateError {
#[error("invalid token")]
InvalidToken,
+ #[error("login deleted")]
+ LoginDeleted,
#[error(transparent)]
Database(#[from] sqlx::Error),
+ #[error(transparent)]
+ Name(#[from] name::Error),
+}
+
+impl From<repo::LoadError> for ValidateError {
+ fn from(error: repo::LoadError) -> Self {
+ match error {
+ repo::LoadError::Database(error) => error.into(),
+ repo::LoadError::Name(error) => error.into(),
+ }
+ }
}
#[derive(Debug)]
diff --git a/src/token/repo/auth.rs b/src/token/repo/auth.rs
index c621b65..bdc4c33 100644
--- a/src/token/repo/auth.rs
+++ b/src/token/repo/auth.rs
@@ -2,8 +2,10 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
use crate::{
clock::DateTime,
+ db::NotFound,
event::{Instant, Sequence},
- login::{self, password::StoredHash, History, Login, Name},
+ login::{self, password::StoredHash, History, Login},
+ name::{self, Name},
};
pub trait Provider {
@@ -19,35 +21,53 @@ impl<'c> Provider for Transaction<'c, Sqlite> {
pub struct Auth<'t>(&'t mut SqliteConnection);
impl<'t> Auth<'t> {
- pub async fn for_name(&mut self, name: &Name) -> Result<(History, StoredHash), sqlx::Error> {
- let found = sqlx::query!(
+ pub async fn for_name(&mut self, name: &Name) -> Result<(History, StoredHash), LoadError> {
+ let name = name.canonical();
+ let row = sqlx::query!(
r#"
- select
- id as "id: login::Id",
- name as "name: Name",
- password_hash as "password_hash: StoredHash",
+ select
+ id as "id: login::Id",
+ display_name as "display_name: String",
+ canonical_name as "canonical_name: String",
created_sequence as "created_sequence: Sequence",
- created_at as "created_at: DateTime"
- from login
- where name = $1
- "#,
+ created_at as "created_at: DateTime",
+ password_hash as "password_hash: StoredHash"
+ from login
+ where canonical_name = $1
+ "#,
name,
)
- .map(|row| {
- (
- History {
- login: Login {
- id: row.id,
- name: row.name,
- },
- created: Instant::new(row.created_at, row.created_sequence),
- },
- row.password_hash,
- )
- })
.fetch_one(&mut *self.0)
.await?;
- Ok(found)
+ let login = History {
+ login: Login {
+ id: row.id,
+ name: Name::new(row.display_name, row.canonical_name)?,
+ },
+ created: Instant::new(row.created_at, row.created_sequence),
+ };
+
+ Ok((login, row.password_hash))
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum LoadError {
+ Database(#[from] sqlx::Error),
+ Name(#[from] name::Error),
+}
+
+impl<T> NotFound for Result<T, LoadError> {
+ type Ok = T;
+ type Error = LoadError;
+
+ fn optional(self) -> Result<Option<T>, LoadError> {
+ match self {
+ Ok(value) => Ok(Some(value)),
+ Err(LoadError::Database(sqlx::Error::RowNotFound)) => Ok(None),
+ Err(other) => Err(other),
+ }
}
}
diff --git a/src/token/repo/mod.rs b/src/token/repo/mod.rs
index 9169743..d8463eb 100644
--- a/src/token/repo/mod.rs
+++ b/src/token/repo/mod.rs
@@ -1,4 +1,4 @@
pub mod auth;
mod token;
-pub use self::token::Provider;
+pub use self::token::{LoadError, Provider};
diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs
index 960bb72..35ea385 100644
--- a/src/token/repo/token.rs
+++ b/src/token/repo/token.rs
@@ -3,7 +3,10 @@ use uuid::Uuid;
use crate::{
clock::DateTime,
- login::{self, History, Login, Name},
+ db::NotFound,
+ event::{Instant, Sequence},
+ login::{self, History, Login},
+ name::{self, Name},
token::{Id, Secret},
};
@@ -100,53 +103,78 @@ impl<'c> Tokens<'c> {
}
// Validate a token by its secret, retrieving the associated Login record.
- // Will return [None] if the token is not valid. The token's last-used
- // timestamp will be set to `used_at`.
+ // Will return an error if the token is not valid. If successful, the
+ // retrieved token's last-used timestamp will be set to `used_at`.
pub async fn validate(
&mut self,
secret: &Secret,
used_at: &DateTime,
- ) -> Result<(Id, Login), sqlx::Error> {
+ ) -> Result<(Id, History), LoadError> {
// I would use `update … returning` to do this in one query, but
// sqlite3, as of this writing, does not allow an update's `returning`
// clause to reference columns from tables joined into the update. Two
// queries is fine, but it feels untidy.
- sqlx::query!(
+ let (token, login) = sqlx::query!(
r#"
update token
set last_used_at = $1
where secret = $2
+ returning
+ id as "token: Id",
+ login as "login: login::Id"
"#,
used_at,
secret,
)
- .execute(&mut *self.0)
+ .map(|row| (row.token, row.login))
+ .fetch_one(&mut *self.0)
.await?;
let login = sqlx::query!(
r#"
select
- token.id as "token_id: Id",
- login.id as "login_id: login::Id",
- login.name as "login_name: Name"
+ id as "id: login::Id",
+ display_name as "display_name: String",
+ canonical_name as "canonical_name: String",
+ created_sequence as "created_sequence: Sequence",
+ created_at as "created_at: DateTime"
from login
- join token on login.id = token.login
- where token.secret = $1
+ where id = $1
"#,
- secret,
+ login,
)
.map(|row| {
- (
- row.token_id,
- Login {
- id: row.login_id,
- name: row.login_name,
+ Ok::<_, name::Error>(History {
+ login: Login {
+ id: row.id,
+ name: Name::new(row.display_name, row.canonical_name)?,
},
- )
+ created: Instant::new(row.created_at, row.created_sequence),
+ })
})
.fetch_one(&mut *self.0)
- .await?;
+ .await??;
+
+ Ok((token, login))
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+pub enum LoadError {
+ Database(#[from] sqlx::Error),
+ Name(#[from] name::Error),
+}
+
+impl<T> NotFound for Result<T, LoadError> {
+ type Ok = T;
+ type Error = LoadError;
- Ok(login)
+ fn optional(self) -> Result<Option<T>, LoadError> {
+ match self {
+ Ok(value) => Ok(Some(value)),
+ Err(LoadError::Database(sqlx::Error::RowNotFound)) => Ok(None),
+ Err(other) => Err(other),
+ }
}
}