diff options
| author | Owen Jacobson <owen@grimoire.ca> | 2024-10-22 19:12:34 -0400 |
|---|---|---|
| committer | Owen Jacobson <owen@grimoire.ca> | 2024-10-22 19:12:34 -0400 |
| commit | 6430854352745f45281021c305b4e350bc92d535 (patch) | |
| tree | c6901c22a45e36415f63efe988d4d4f2a309df81 /src | |
| parent | 98af8ff80da919a1126ba7c6afa65e6654b5ecde (diff) | |
| parent | db940bacd096a33a65f29759e70ea1acf6186a67 (diff) | |
Merge branch 'unicode-normalization'
Diffstat (limited to 'src')
43 files changed, 970 insertions, 245 deletions
@@ -5,14 +5,12 @@ use crate::{ channel::app::Channels, event::{self, app::Events}, invite::app::Invites, + login::app::Logins, message::app::Messages, setup::app::Setup, token::{self, app::Tokens}, }; -#[cfg(test)] -use crate::login::app::Logins; - #[derive(Clone)] pub struct App { db: SqlitePool, @@ -49,6 +47,11 @@ impl App { Invites::new(&self.db) } + #[cfg(not(test))] + pub const fn logins(&self) -> Logins { + Logins::new(&self.db) + } + #[cfg(test)] pub const fn logins(&self) -> Logins { Logins::new(&self.db, &self.events) diff --git a/src/bin/hi-recanonicalize.rs b/src/bin/hi-recanonicalize.rs new file mode 100644 index 0000000..4081276 --- /dev/null +++ b/src/bin/hi-recanonicalize.rs @@ -0,0 +1,9 @@ +use clap::Parser; + +use hi::cli; + +#[tokio::main] +async fn main() -> Result<(), cli::recanonicalize::Error> { + let args = cli::recanonicalize::Args::parse(); + args.run().await +} diff --git a/src/main.rs b/src/bin/hi.rs index d0830ff..d0830ff 100644 --- a/src/main.rs +++ b/src/bin/hi.rs diff --git a/src/boot/app.rs b/src/boot/app.rs index ef48b2f..1d88608 100644 --- a/src/boot/app.rs +++ b/src/boot/app.rs @@ -2,8 +2,11 @@ use sqlx::sqlite::SqlitePool; use super::Snapshot; use crate::{ - channel::repo::Provider as _, event::repo::Provider as _, login::repo::Provider as _, + channel::{self, repo::Provider as _}, + event::repo::Provider as _, + login::{self, repo::Provider as _}, message::repo::Provider as _, + name, }; pub struct Boot<'a> { @@ -15,7 +18,7 @@ impl<'a> Boot<'a> { Self { db } } - pub async fn snapshot(&self) -> Result<Snapshot, sqlx::Error> { + pub async fn snapshot(&self) -> Result<Snapshot, Error> { let mut tx = self.db.begin().await?; let resume_point = tx.sequence().current().await?; @@ -48,3 +51,30 @@ impl<'a> Boot<'a> { }) } } + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum Error { + Name(#[from] name::Error), + Database(#[from] sqlx::Error), +} + +impl From<login::repo::LoadError> for Error { + fn from(error: login::repo::LoadError) -> Self { + use login::repo::LoadError; + match error { + LoadError::Name(error) => error.into(), + LoadError::Database(error) => error.into(), + } + } +} + +impl From<channel::repo::LoadError> for Error { + fn from(error: channel::repo::LoadError) -> Self { + use channel::repo::LoadError; + match error { + LoadError::Name(error) => error.into(), + LoadError::Database(error) => error.into(), + } + } +} diff --git a/src/channel/app.rs b/src/channel/app.rs index 75c662d..7bfa0f7 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,12 +2,16 @@ use chrono::TimeDelta; use itertools::Itertools; use sqlx::sqlite::SqlitePool; -use super::{repo::Provider as _, Channel, History, Id}; +use super::{ + repo::{LoadError, Provider as _}, + Channel, History, Id, +}; use crate::{ clock::DateTime, db::{Duplicate as _, NotFound as _}, event::{repo::Provider as _, Broadcaster, Event, Sequence}, message::repo::Provider as _, + name::{self, Name}, }; pub struct Channels<'a> { @@ -20,14 +24,14 @@ impl<'a> Channels<'a> { Self { db, events } } - pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> { + pub async fn create(&self, name: &Name, created_at: &DateTime) -> Result<Channel, CreateError> { let mut tx = self.db.begin().await?; let created = tx.sequence().next(created_at).await?; let channel = tx .channels() .create(name, &created) .await - .duplicate(|| CreateError::DuplicateName(name.into()))?; + .duplicate(|| CreateError::DuplicateName(name.clone()))?; tx.commit().await?; self.events @@ -38,7 +42,7 @@ impl<'a> Channels<'a> { // This function is careless with respect to time, and gets you the channel as // it exists in the specific moment when you call it. - pub async fn get(&self, channel: &Id) -> Result<Option<Channel>, sqlx::Error> { + pub async fn get(&self, channel: &Id) -> Result<Option<Channel>, Error> { let mut tx = self.db.begin().await?; let channel = tx.channels().by_id(channel).await.optional()?; tx.commit().await?; @@ -88,7 +92,7 @@ impl<'a> Channels<'a> { Ok(()) } - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), ExpireError> { // Somewhat arbitrarily, expire after 90 days. let expire_at = relative_to.to_owned() - TimeDelta::days(90); @@ -129,14 +133,33 @@ impl<'a> Channels<'a> { Ok(()) } + + pub async fn recanonicalize(&self) -> Result<(), sqlx::Error> { + let mut tx = self.db.begin().await?; + tx.channels().recanonicalize().await?; + tx.commit().await?; + + Ok(()) + } } #[derive(Debug, thiserror::Error)] pub enum CreateError { #[error("channel named {0} already exists")] - DuplicateName(String), + DuplicateName(Name), #[error(transparent)] Database(#[from] sqlx::Error), + #[error(transparent)] + Name(#[from] name::Error), +} + +impl From<LoadError> for CreateError { + fn from(error: LoadError) -> Self { + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } } #[derive(Debug, thiserror::Error)] @@ -147,4 +170,32 @@ pub enum Error { Deleted(Id), #[error(transparent)] Database(#[from] sqlx::Error), + #[error(transparent)] + Name(#[from] name::Error), +} + +impl From<LoadError> for Error { + fn from(error: LoadError) -> Self { + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ExpireError { + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] + Name(#[from] name::Error), +} + +impl From<LoadError> for ExpireError { + fn from(error: LoadError) -> Self { + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } } diff --git a/src/channel/repo.rs b/src/channel/repo.rs index 27d35f0..e26ac2b 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -1,9 +1,12 @@ +use futures::stream::{StreamExt as _, TryStreamExt as _}; use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ channel::{Channel, History, Id}, clock::DateTime, + db::NotFound, event::{Instant, ResumePoint, Sequence}, + name::{self, Name}, }; pub trait Provider { @@ -19,134 +22,162 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); impl<'c> Channels<'c> { - pub async fn create(&mut self, name: &str, created: &Instant) -> Result<History, sqlx::Error> { + pub async fn create(&mut self, name: &Name, created: &Instant) -> Result<History, sqlx::Error> { let id = Id::generate(); - let channel = sqlx::query!( + let name = name.clone(); + let display_name = name.display(); + let canonical_name = name.canonical(); + let created = *created; + + sqlx::query!( r#" insert - into channel (id, name, created_at, created_sequence) - values ($1, $2, $3, $4) - returning - id as "id: Id", - name as "name!", -- known non-null as we just set it - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" + into channel (id, created_at, created_sequence) + values ($1, $2, $3) "#, id, - name, created.at, created.sequence, ) - .map(|row| History { + .execute(&mut *self.0) + .await?; + + sqlx::query!( + r#" + insert into channel_name (id, display_name, canonical_name) + values ($1, $2, $3) + "#, + id, + display_name, + canonical_name, + ) + .execute(&mut *self.0) + .await?; + + let channel = History { channel: Channel { - id: row.id, - name: row.name, + id, + name: name.clone(), deleted_at: None, }, - created: Instant::new(row.created_at, row.created_sequence), + created, deleted: None, - }) - .fetch_one(&mut *self.0) - .await?; + }; Ok(channel) } - pub async fn by_id(&mut self, channel: &Id) -> Result<History, sqlx::Error> { + pub async fn by_id(&mut self, channel: &Id) -> Result<History, LoadError> { let channel = sqlx::query!( r#" select id as "id: Id", - channel.name, + name.display_name as "display_name?: String", + name.canonical_name as "canonical_name?: String", channel.created_at as "created_at: DateTime", channel.created_sequence as "created_sequence: Sequence", deleted.deleted_at as "deleted_at?: DateTime", deleted.deleted_sequence as "deleted_sequence?: Sequence" from channel + left join channel_name as name + using (id) left join channel_deleted as deleted using (id) where id = $1 "#, channel, ) - .map(|row| History { - channel: Channel { - id: row.id, - name: row.name.unwrap_or_default(), - deleted_at: row.deleted_at, - }, - created: Instant::new(row.created_at, row.created_sequence), - deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + .map(|row| { + Ok::<_, name::Error>(History { + channel: Channel { + id: row.id, + name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(), + deleted_at: row.deleted_at, + }, + created: Instant::new(row.created_at, row.created_sequence), + deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + }) }) .fetch_one(&mut *self.0) - .await?; + .await??; Ok(channel) } - pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> { + pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> { let channels = sqlx::query!( r#" select id as "id: Id", - channel.name, + name.display_name as "display_name: String", + name.canonical_name as "canonical_name: String", channel.created_at as "created_at: DateTime", channel.created_sequence as "created_sequence: Sequence", - deleted.deleted_at as "deleted_at: DateTime", - deleted.deleted_sequence as "deleted_sequence: Sequence" + deleted.deleted_at as "deleted_at?: DateTime", + deleted.deleted_sequence as "deleted_sequence?: Sequence" from channel + left join channel_name as name + using (id) left join channel_deleted as deleted using (id) where coalesce(channel.created_sequence <= $1, true) - order by channel.name + order by name.canonical_name "#, resume_at, ) - .map(|row| History { - channel: Channel { - id: row.id, - name: row.name.unwrap_or_default(), - deleted_at: row.deleted_at, - }, - created: Instant::new(row.created_at, row.created_sequence), - deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + .map(|row| { + Ok::<_, name::Error>(History { + channel: Channel { + id: row.id, + name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(), + deleted_at: row.deleted_at, + }, + created: Instant::new(row.created_at, row.created_sequence), + deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + }) }) - .fetch_all(&mut *self.0) + .fetch(&mut *self.0) + .map(|res| Ok::<_, LoadError>(res??)) + .try_collect() .await?; Ok(channels) } - pub async fn replay( - &mut self, - resume_at: Option<Sequence>, - ) -> Result<Vec<History>, sqlx::Error> { + pub async fn replay(&mut self, resume_at: Option<Sequence>) -> Result<Vec<History>, LoadError> { let channels = sqlx::query!( r#" select id as "id: Id", - channel.name, + name.display_name as "display_name: String", + name.canonical_name as "canonical_name: String", channel.created_at as "created_at: DateTime", channel.created_sequence as "created_sequence: Sequence", - deleted.deleted_at as "deleted_at: DateTime", - deleted.deleted_sequence as "deleted_sequence: Sequence" + deleted.deleted_at as "deleted_at?: DateTime", + deleted.deleted_sequence as "deleted_sequence?: Sequence" from channel + left join channel_name as name + using (id) left join channel_deleted as deleted using (id) where coalesce(channel.created_sequence > $1, true) "#, resume_at, ) - .map(|row| History { - channel: Channel { - id: row.id, - name: row.name.unwrap_or_default(), - deleted_at: row.deleted_at, - }, - created: Instant::new(row.created_at, row.created_sequence), - deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + .map(|row| { + Ok::<_, name::Error>(History { + channel: Channel { + id: row.id, + name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(), + deleted_at: row.deleted_at, + }, + created: Instant::new(row.created_at, row.created_sequence), + deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + }) }) - .fetch_all(&mut *self.0) + .fetch(&mut *self.0) + .map(|res| Ok::<_, LoadError>(res??)) + .try_collect() .await?; Ok(channels) @@ -156,19 +187,18 @@ impl<'c> Channels<'c> { &mut self, channel: &History, deleted: &Instant, - ) -> Result<History, sqlx::Error> { + ) -> Result<History, LoadError> { let id = channel.id(); - sqlx::query_scalar!( + sqlx::query!( r#" insert into channel_deleted (id, deleted_at, deleted_sequence) values ($1, $2, $3) - returning 1 as "deleted: bool" "#, id, deleted.at, deleted.sequence, ) - .fetch_one(&mut *self.0) + .execute(&mut *self.0) .await?; // Small social responsibility hack here: when a channel is deleted, its name is @@ -179,16 +209,14 @@ impl<'c> Channels<'c> { // This also avoids the need for a separate name reservation table to ensure // that live channels have unique names, since the `channel` table's name field // is unique over non-null values. - sqlx::query_scalar!( + sqlx::query!( r#" - update channel - set name = null + delete from channel_name where id = $1 - returning 1 as "updated: bool" "#, id, ) - .fetch_one(&mut *self.0) + .execute(&mut *self.0) .await?; let channel = self.by_id(id).await?; @@ -230,38 +258,98 @@ impl<'c> Channels<'c> { Ok(()) } - pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<History>, sqlx::Error> { + pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<History>, LoadError> { let channels = sqlx::query!( r#" select channel.id as "id: Id", - channel.name, + name.display_name as "display_name: String", + name.canonical_name as "canonical_name: String", channel.created_at as "created_at: DateTime", channel.created_sequence as "created_sequence: Sequence", deleted.deleted_at as "deleted_at?: DateTime", deleted.deleted_sequence as "deleted_sequence?: Sequence" from channel + left join channel_name as name + using (id) left join channel_deleted as deleted using (id) left join message + on channel.id = message.channel where channel.created_at < $1 and message.id is null and deleted.id is null "#, expired_at, ) - .map(|row| History { - channel: Channel { - id: row.id, - name: row.name.unwrap_or_default(), - deleted_at: row.deleted_at, - }, - created: Instant::new(row.created_at, row.created_sequence), - deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + .map(|row| { + Ok::<_, name::Error>(History { + channel: Channel { + id: row.id, + name: Name::optional(row.display_name, row.canonical_name)?.unwrap_or_default(), + deleted_at: row.deleted_at, + }, + created: Instant::new(row.created_at, row.created_sequence), + deleted: Instant::optional(row.deleted_at, row.deleted_sequence), + }) }) - .fetch_all(&mut *self.0) + .fetch(&mut *self.0) + .map(|res| Ok::<_, LoadError>(res??)) + .try_collect() .await?; Ok(channels) } + + pub async fn recanonicalize(&mut self) -> Result<(), sqlx::Error> { + let channels = sqlx::query!( + r#" + select + id as "id: Id", + display_name as "display_name: String" + from channel_name + "#, + ) + .fetch_all(&mut *self.0) + .await?; + + for channel in channels { + let name = Name::from(channel.display_name); + let canonical_name = name.canonical(); + + sqlx::query!( + r#" + update channel_name + set canonical_name = $1 + where id = $2 + "#, + canonical_name, + channel.id, + ) + .execute(&mut *self.0) + .await?; + } + + Ok(()) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum LoadError { + Database(#[from] sqlx::Error), + Name(#[from] name::Error), +} + +impl<T> NotFound for Result<T, LoadError> { + type Ok = T; + type Error = LoadError; + + fn optional(self) -> Result<Option<T>, LoadError> { + match self { + Ok(value) => Ok(Some(value)), + Err(LoadError::Database(sqlx::Error::RowNotFound)) => Ok(None), + Err(other) => Err(other), + } + } } diff --git a/src/channel/routes/channel/post.rs b/src/channel/routes/channel/post.rs index b489a77..d0cae05 100644 --- a/src/channel/routes/channel/post.rs +++ b/src/channel/routes/channel/post.rs @@ -9,7 +9,7 @@ use crate::{ clock::RequestedAt, error::{Internal, NotFound}, login::Login, - message::{app::SendError, Message}, + message::{app::SendError, Body, Message}, }; pub async fn handler( @@ -29,7 +29,7 @@ pub async fn handler( #[derive(serde::Deserialize)] pub struct Request { - pub body: String, + pub body: Body, } #[derive(Debug)] diff --git a/src/channel/routes/post.rs b/src/channel/routes/post.rs index a05c312..9781dd7 100644 --- a/src/channel/routes/post.rs +++ b/src/channel/routes/post.rs @@ -10,6 +10,7 @@ use crate::{ clock::RequestedAt, error::Internal, login::Login, + name::Name, }; pub async fn handler( @@ -29,7 +30,7 @@ pub async fn handler( #[derive(serde::Deserialize)] pub struct Request { - pub name: String, + pub name: Name, } #[derive(Debug)] diff --git a/src/channel/snapshot.rs b/src/channel/snapshot.rs index 2b7d89a..129c0d6 100644 --- a/src/channel/snapshot.rs +++ b/src/channel/snapshot.rs @@ -2,12 +2,12 @@ use super::{ event::{Created, Event}, Id, }; -use crate::clock::DateTime; +use crate::{clock::DateTime, name::Name}; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Channel { pub id: Id, - pub name: String, + pub name: Name, #[serde(skip_serializing_if = "Option::is_none")] pub deleted_at: Option<DateTime>, } diff --git a/src/cli.rs b/src/cli/mod.rs index 0659851..c75ce2b 100644 --- a/src/cli.rs +++ b/src/cli/mod.rs @@ -22,6 +22,8 @@ use crate::{ ui, }; +pub mod recanonicalize; + /// Command-line entry point for running the `hi` server. /// /// This is intended to be used as a Clap [Parser], to capture command-line diff --git a/src/cli/recanonicalize.rs b/src/cli/recanonicalize.rs new file mode 100644 index 0000000..5f8a1db --- /dev/null +++ b/src/cli/recanonicalize.rs @@ -0,0 +1,86 @@ +use sqlx::sqlite::SqlitePool; + +use crate::{app::App, db}; + +/// Command-line entry point for repairing canonical names in the `hi` database. +/// This command may be necessary after an upgrade, if the canonical forms of +/// names has changed. It will re-calculate the canonical form of each name in +/// the database, based on its display form, and store the results back to the +/// database. +/// +/// This is intended to be used as a Clap [Parser], to capture command-line +/// arguments for the `hi-recanonicalize` command: +/// +/// ```no_run +/// # use hi::recanonicalize::cli::Error; +/// # +/// # #[tokio::main] +/// # async fn main() -> Result<(), Error> { +/// use clap::Parser; +/// use hi::cli::recanonicalize::Args; +/// +/// let args = Args::parse(); +/// args.run().await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(clap::Parser)] +#[command( + version, + about = "Recanonicalize names in the `hi` database.", + long_about = r#"Recanonicalize names in the `hi` database. + +The `hi` server must not be running while this command is run. + +The database at `--database-url` will also be created, or upgraded, automatically."# +)] +pub struct Args { + /// Sqlite URL or path for the `hi` database + #[arg(short, long, env, default_value = "sqlite://.hi")] + database_url: String, + + /// Sqlite URL or path for a backup of the `hi` database during upgrades + #[arg(short = 'D', long, env, default_value = "sqlite://.hi.backup")] + backup_database_url: String, +} + +impl Args { + /// Recanonicalizes the `hi` database, using the parsed configuation in + /// `self`. + /// + /// This will perform the following tasks: + /// + /// * Migrate the `hi` database (at `--database-url`). + /// * Recanonicalize names in the `login` and `channel` tables. + /// + /// # Errors + /// + /// Will return `Err` if the canonicalization or database upgrade processes + /// fail. The specific [`Error`] variant will expose the cause + /// of the failure. + pub async fn run(self) -> Result<(), Error> { + let pool = self.pool().await?; + + let app = App::from(pool); + app.logins().recanonicalize().await?; + app.channels().recanonicalize().await?; + + Ok(()) + } + + async fn pool(&self) -> Result<SqlitePool, db::Error> { + db::prepare(&self.database_url, &self.backup_database_url).await + } +} + +/// Errors that can be raised by [`Args::run`]. +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum Error { + // /// Failure due to `io::Error`. See [`io::Error`]. + // Io(#[from] io::Error), + /// Failure due to a database initialization error. See [`db::Error`]. + Database(#[from] db::Error), + /// Failure due to a data manipulation error. See [`sqlx::Error`]. + Sqlx(#[from] sqlx::Error), +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 6005813..e0522d4 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -130,14 +130,17 @@ pub enum Error { Rejected(String, String), } -pub trait NotFound { +pub trait NotFound: Sized { type Ok; type Error; fn not_found<E, F>(self, map: F) -> Result<Self::Ok, E> where E: From<Self::Error>, - F: FnOnce() -> E; + F: FnOnce() -> E, + { + self.optional()?.ok_or_else(map) + } fn optional(self) -> Result<Option<Self::Ok>, Self::Error>; } @@ -153,14 +156,6 @@ impl<T> NotFound for Result<T, sqlx::Error> { Err(other) => Err(other), } } - - fn not_found<E, F>(self, map: F) -> Result<T, E> - where - E: From<sqlx::Error>, - F: FnOnce() -> E, - { - self.optional()?.ok_or_else(map) - } } pub trait Duplicate { diff --git a/src/event/app.rs b/src/event/app.rs index 951ce25..c754388 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -11,6 +11,7 @@ use crate::{ channel::{self, repo::Provider as _}, login::{self, repo::Provider as _}, message::{self, repo::Provider as _}, + name, }; pub struct Events<'a> { @@ -26,7 +27,7 @@ impl<'a> Events<'a> { pub async fn subscribe( &self, resume_at: impl Into<ResumePoint>, - ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, sqlx::Error> { + ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, Error> { let resume_at = resume_at.into(); // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. @@ -81,3 +82,30 @@ impl<'a> Events<'a> { move |event| future::ready(filter(event)) } } + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum Error { + Database(#[from] sqlx::Error), + Name(#[from] name::Error), +} + +impl From<login::repo::LoadError> for Error { + fn from(error: login::repo::LoadError) -> Self { + use login::repo::LoadError; + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } +} + +impl From<channel::repo::LoadError> for Error { + fn from(error: channel::repo::LoadError) -> Self { + use channel::repo::LoadError; + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } +} diff --git a/src/event/routes/get.rs b/src/event/routes/get.rs index 357845a..22e8762 100644 --- a/src/event/routes/get.rs +++ b/src/event/routes/get.rs @@ -12,7 +12,7 @@ use futures::stream::{Stream, StreamExt as _}; use crate::{ app::App, error::{Internal, Unauthorized}, - event::{extract::LastEventId, Event, ResumePoint, Sequence, Sequenced as _}, + event::{app, extract::LastEventId, Event, ResumePoint, Sequence, Sequenced as _}, token::{app::ValidateError, extract::Identity}, }; @@ -69,7 +69,7 @@ impl TryFrom<Event> for sse::Event { #[derive(Debug, thiserror::Error)] #[error(transparent)] pub enum Error { - Database(#[from] sqlx::Error), + Subscribe(#[from] app::Error), Validate(#[from] ValidateError), } diff --git a/src/invite/app.rs b/src/invite/app.rs index ee7f74f..64ba753 100644 --- a/src/invite/app.rs +++ b/src/invite/app.rs @@ -7,6 +7,7 @@ use crate::{ db::{Duplicate as _, NotFound as _}, event::repo::Provider as _, login::{repo::Provider as _, Login, Password}, + name::Name, token::{repo::Provider as _, Secret}, }; @@ -42,7 +43,7 @@ impl<'a> Invites<'a> { pub async fn accept( &self, invite: &Id, - name: &str, + name: &Name, password: &Password, accepted_at: &DateTime, ) -> Result<(Login, Secret), AcceptError> { @@ -68,7 +69,7 @@ impl<'a> Invites<'a> { .logins() .create(name, &password_hash, &created) .await - .duplicate(|| AcceptError::DuplicateLogin(name.into()))?; + .duplicate(|| AcceptError::DuplicateLogin(name.clone()))?; let secret = tx.tokens().issue(&login, accepted_at).await?; tx.commit().await?; @@ -92,7 +93,7 @@ pub enum AcceptError { #[error("invite not found: {0}")] NotFound(Id), #[error("name in use: {0}")] - DuplicateLogin(String), + DuplicateLogin(Name), #[error(transparent)] Database(#[from] sqlx::Error), #[error(transparent)] diff --git a/src/invite/mod.rs b/src/invite/mod.rs index abf1c3a..d59fb9c 100644 --- a/src/invite/mod.rs +++ b/src/invite/mod.rs @@ -3,7 +3,7 @@ mod id; mod repo; mod routes; -use crate::{clock::DateTime, login}; +use crate::{clock::DateTime, login, normalize::nfc}; pub use self::{id::Id, routes::router}; @@ -17,6 +17,6 @@ pub struct Invite { #[derive(serde::Serialize)] pub struct Summary { pub id: Id, - pub issuer: String, + pub issuer: nfc::String, pub issued_at: DateTime, } diff --git a/src/invite/repo.rs b/src/invite/repo.rs index 643f5b7..02f4e42 100644 --- a/src/invite/repo.rs +++ b/src/invite/repo.rs @@ -4,6 +4,7 @@ use super::{Id, Invite, Summary}; use crate::{ clock::DateTime, login::{self, Login}, + normalize::nfc, }; pub trait Provider { @@ -70,7 +71,7 @@ impl<'c> Invites<'c> { select invite.id as "invite_id: Id", issuer.id as "issuer_id: login::Id", - issuer.name as "issuer_name", + issuer.display_name as "issuer_name: nfc::String", invite.issued_at as "invite_issued_at: DateTime" from invite join login as issuer on (invite.issuer = issuer.id) diff --git a/src/invite/routes/invite/post.rs b/src/invite/routes/invite/post.rs index c072929..a41207a 100644 --- a/src/invite/routes/invite/post.rs +++ b/src/invite/routes/invite/post.rs @@ -10,6 +10,7 @@ use crate::{ error::{Internal, NotFound}, invite::app, login::{Login, Password}, + name::Name, token::extract::IdentityToken, }; @@ -31,7 +32,7 @@ pub async fn handler( #[derive(serde::Deserialize)] pub struct Request { - pub name: String, + pub name: Name, pub password: Password, } @@ -16,6 +16,8 @@ mod id; mod invite; mod login; mod message; +mod name; +mod normalize; mod setup; #[cfg(test)] mod test; diff --git a/src/login/app.rs b/src/login/app.rs index b6f7e1c..2f5896f 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,24 +1,37 @@ use sqlx::sqlite::SqlitePool; -use super::{repo::Provider as _, Login, Password}; +use super::repo::Provider as _; + +#[cfg(test)] +use super::{Login, Password}; +#[cfg(test)] use crate::{ clock::DateTime, event::{repo::Provider as _, Broadcaster, Event}, + name::Name, }; pub struct Logins<'a> { db: &'a SqlitePool, + #[cfg(test)] events: &'a Broadcaster, } impl<'a> Logins<'a> { + #[cfg(not(test))] + pub const fn new(db: &'a SqlitePool) -> Self { + Self { db } + } + + #[cfg(test)] pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { Self { db, events } } + #[cfg(test)] pub async fn create( &self, - name: &str, + name: &Name, password: &Password, created_at: &DateTime, ) -> Result<Login, CreateError> { @@ -34,6 +47,14 @@ impl<'a> Logins<'a> { Ok(login.as_created()) } + + pub async fn recanonicalize(&self) -> Result<(), sqlx::Error> { + let mut tx = self.db.begin().await?; + tx.logins().recanonicalize().await?; + tx.commit().await?; + + Ok(()) + } } #[derive(Debug, thiserror::Error)] diff --git a/src/login/mod.rs b/src/login/mod.rs index 98cc3d7..64a3698 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -1,4 +1,3 @@ -#[cfg(test)] pub mod app; pub mod event; pub mod extract; diff --git a/src/login/password.rs b/src/login/password.rs index 14fd981..c27c950 100644 --- a/src/login/password.rs +++ b/src/login/password.rs @@ -4,6 +4,8 @@ use argon2::Argon2; use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; use rand_core::OsRng; +use crate::normalize::nfc; + #[derive(sqlx::Type)] #[sqlx(transparent)] pub struct StoredHash(String); @@ -31,7 +33,7 @@ impl fmt::Debug for StoredHash { #[derive(serde::Deserialize)] #[serde(transparent)] -pub struct Password(String); +pub struct Password(nfc::String); impl Password { pub fn hash(&self) -> Result<StoredHash, password_hash::Error> { @@ -56,9 +58,8 @@ impl fmt::Debug for Password { } } -#[cfg(test)] impl From<String> for Password { fn from(password: String) -> Self { - Self(password) + Password(password.into()) } } diff --git a/src/login/repo.rs b/src/login/repo.rs index 7d0fcb1..c6bc734 100644 --- a/src/login/repo.rs +++ b/src/login/repo.rs @@ -1,9 +1,11 @@ +use futures::stream::{StreamExt as _, TryStreamExt as _}; use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, event::{Instant, ResumePoint, Sequence}, login::{password::StoredHash, History, Id, Login}, + name::{self, Name}, }; pub trait Provider { @@ -21,48 +23,48 @@ pub struct Logins<'t>(&'t mut SqliteConnection); impl<'c> Logins<'c> { pub async fn create( &mut self, - name: &str, + name: &Name, password_hash: &StoredHash, created: &Instant, ) -> Result<History, sqlx::Error> { let id = Id::generate(); + let display_name = name.display(); + let canonical_name = name.canonical(); - let login = sqlx::query!( + sqlx::query!( r#" insert - into login (id, name, password_hash, created_sequence, created_at) - values ($1, $2, $3, $4, $5) - returning - id as "id: Id", - name, - created_sequence as "created_sequence: Sequence", - created_at as "created_at: DateTime" + into login (id, display_name, canonical_name, password_hash, created_sequence, created_at) + values ($1, $2, $3, $4, $5, $6) "#, id, - name, + display_name, + canonical_name, password_hash, created.sequence, created.at, ) - .map(|row| History { + .execute(&mut *self.0) + .await?; + + let login = History { + created: *created, login: Login { - id: row.id, - name: row.name, + id, + name: name.clone(), }, - created: Instant::new(row.created_at, row.created_sequence), - }) - .fetch_one(&mut *self.0) - .await?; + }; Ok(login) } - pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> { - let channels = sqlx::query!( + pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> { + let logins = sqlx::query!( r#" select id as "id: Id", - name, + display_name as "display_name: String", + canonical_name as "canonical_name: String", created_sequence as "created_sequence: Sequence", created_at as "created_at: DateTime" from login @@ -71,24 +73,30 @@ impl<'c> Logins<'c> { "#, resume_at, ) - .map(|row| History { - login: Login { - id: row.id, - name: row.name, - }, - created: Instant::new(row.created_at, row.created_sequence), + .map(|row| { + Ok::<_, LoadError>(History { + login: Login { + id: row.id, + name: Name::new(row.display_name, row.canonical_name)?, + }, + created: Instant::new(row.created_at, row.created_sequence), + }) }) - .fetch_all(&mut *self.0) + .fetch(&mut *self.0) + .map(|res| res?) + .try_collect() .await?; - Ok(channels) + Ok(logins) } - pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> { - let messages = sqlx::query!( + + pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> { + let logins = sqlx::query!( r#" select id as "id: Id", - name, + display_name as "display_name: String", + canonical_name as "canonical_name: String", created_sequence as "created_sequence: Sequence", created_at as "created_at: DateTime" from login @@ -96,22 +104,59 @@ impl<'c> Logins<'c> { "#, resume_at, ) - .map(|row| History { - login: Login { - id: row.id, - name: row.name, - }, - created: Instant::new(row.created_at, row.created_sequence), + .map(|row| { + Ok::<_, name::Error>(History { + login: Login { + id: row.id, + name: Name::new(row.display_name, row.canonical_name)?, + }, + created: Instant::new(row.created_at, row.created_sequence), + }) }) + .fetch(&mut *self.0) + .map(|res| Ok::<_, LoadError>(res??)) + .try_collect() + .await?; + + Ok(logins) + } + + pub async fn recanonicalize(&mut self) -> Result<(), sqlx::Error> { + let logins = sqlx::query!( + r#" + select + id as "id: Id", + display_name as "display_name: String" + from login + "#, + ) .fetch_all(&mut *self.0) .await?; - Ok(messages) + for login in logins { + let name = Name::from(login.display_name); + let canonical_name = name.canonical(); + + sqlx::query!( + r#" + update login + set canonical_name = $1 + where id = $2 + "#, + canonical_name, + login.id, + ) + .execute(&mut *self.0) + .await?; + } + + Ok(()) } } -impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { - fn from(tx: &'t mut SqliteConnection) -> Self { - Self(tx) - } +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum LoadError { + Database(#[from] sqlx::Error), + Name(#[from] name::Error), } diff --git a/src/login/routes/login/post.rs b/src/login/routes/login/post.rs index 67eaa6d..20430db 100644 --- a/src/login/routes/login/post.rs +++ b/src/login/routes/login/post.rs @@ -9,6 +9,7 @@ use crate::{ clock::RequestedAt, error::Internal, login::{Login, Password}, + name::Name, token::{app, extract::IdentityToken}, }; @@ -29,7 +30,7 @@ pub async fn handler( #[derive(serde::Deserialize)] pub struct Request { - pub name: String, + pub name: Name, pub password: Password, } diff --git a/src/login/routes/logout/test.rs b/src/login/routes/logout/test.rs index 0e70e4c..91837fe 100644 --- a/src/login/routes/logout/test.rs +++ b/src/login/routes/logout/test.rs @@ -33,7 +33,6 @@ async fn successful() { assert_eq!(StatusCode::NO_CONTENT, response_status); // Verify the semantics - let error = app .tokens() .validate(&secret, &now) diff --git a/src/login/snapshot.rs b/src/login/snapshot.rs index 1a92f5c..e1eb96c 100644 --- a/src/login/snapshot.rs +++ b/src/login/snapshot.rs @@ -2,6 +2,7 @@ use super::{ event::{Created, Event}, Id, }; +use crate::name::Name; // This also implements FromRequestParts (see `./extract.rs`). As a result, it // can be used as an extractor for endpoints that want to require login, or for @@ -10,7 +11,7 @@ use super::{ #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Login { pub id: Id, - pub name: String, + pub name: Name, // The omission of the hashed password is deliberate, to minimize the // chance that it ends up tangled up in debug output or in some other chunk // of logic elsewhere. diff --git a/src/message/app.rs b/src/message/app.rs index 4e50513..852b958 100644 --- a/src/message/app.rs +++ b/src/message/app.rs @@ -2,13 +2,14 @@ use chrono::TimeDelta; use itertools::Itertools; use sqlx::sqlite::SqlitePool; -use super::{repo::Provider as _, Id, Message}; +use super::{repo::Provider as _, Body, Id, Message}; use crate::{ channel::{self, repo::Provider as _}, clock::DateTime, db::NotFound as _, event::{repo::Provider as _, Broadcaster, Event, Sequence}, login::Login, + name, }; pub struct Messages<'a> { @@ -26,7 +27,7 @@ impl<'a> Messages<'a> { channel: &channel::Id, sender: &Login, sent_at: &DateTime, - body: &str, + body: &Body, ) -> Result<Message, SendError> { let mut tx = self.db.begin().await?; let channel = tx @@ -119,6 +120,18 @@ pub enum SendError { ChannelNotFound(channel::Id), #[error(transparent)] Database(#[from] sqlx::Error), + #[error(transparent)] + Name(#[from] name::Error), +} + +impl From<channel::repo::LoadError> for SendError { + fn from(error: channel::repo::LoadError) -> Self { + use channel::repo::LoadError; + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } } #[derive(Debug, thiserror::Error)] diff --git a/src/message/body.rs b/src/message/body.rs new file mode 100644 index 0000000..6dd224c --- /dev/null +++ b/src/message/body.rs @@ -0,0 +1,30 @@ +use std::fmt; + +use crate::normalize::nfc; + +#[derive( + Clone, Debug, Default, Eq, PartialEq, serde::Deserialize, serde::Serialize, sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Body(nfc::String); + +impl fmt::Display for Body { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(body) = self; + body.fmt(f) + } +} + +impl From<String> for Body { + fn from(body: String) -> Self { + Self(body.into()) + } +} + +impl From<Body> for String { + fn from(body: Body) -> Self { + let Body(body) = body; + body.into() + } +} diff --git a/src/message/mod.rs b/src/message/mod.rs index a8f51ab..c2687bc 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -1,4 +1,5 @@ pub mod app; +mod body; pub mod event; mod history; mod id; @@ -6,4 +7,6 @@ pub mod repo; mod routes; mod snapshot; -pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Message}; +pub use self::{ + body::Body, event::Event, history::History, id::Id, routes::router, snapshot::Message, +}; diff --git a/src/message/repo.rs b/src/message/repo.rs index 85a69fc..4cfefec 100644 --- a/src/message/repo.rs +++ b/src/message/repo.rs @@ -1,6 +1,6 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use super::{snapshot::Message, History, Id}; +use super::{snapshot::Message, Body, History, Id}; use crate::{ channel, clock::DateTime, @@ -26,24 +26,24 @@ impl<'c> Messages<'c> { channel: &channel::History, sender: &Login, sent: &Instant, - body: &str, + body: &Body, ) -> Result<History, sqlx::Error> { let id = Id::generate(); let channel_id = channel.id(); let message = sqlx::query!( r#" - insert into message - (id, channel, sender, sent_at, sent_sequence, body) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: Id", + insert into message + (id, channel, sender, sent_at, sent_sequence, body) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: Id", channel as "channel: channel::Id", sender as "sender: login::Id", sent_at as "sent_at: DateTime", sent_sequence as "sent_sequence: Sequence", - body - "#, + body as "body: Body" + "#, id, channel_id, sender.id, @@ -76,7 +76,7 @@ impl<'c> Messages<'c> { message.channel as "channel: channel::Id", message.sender as "sender: login::Id", id as "id: Id", - message.body, + message.body as "body: Body", message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", deleted.deleted_at as "deleted_at: DateTime", @@ -113,7 +113,7 @@ impl<'c> Messages<'c> { message.channel as "channel: channel::Id", message.sender as "sender: login::Id", id as "id: Id", - message.body, + message.body as "body: Body", message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", deleted.deleted_at as "deleted_at: DateTime", @@ -150,7 +150,7 @@ impl<'c> Messages<'c> { message.channel as "channel: channel::Id", message.sender as "sender: login::Id", id as "id: Id", - message.body, + message.body as "body: Body", message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", deleted.deleted_at as "deleted_at?: DateTime", @@ -256,7 +256,7 @@ impl<'c> Messages<'c> { message.sender as "sender: login::Id", message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", - message.body, + message.body as "body: Body", deleted.deleted_at as "deleted_at?: DateTime", deleted.deleted_sequence as "deleted_sequence?: Sequence" from message @@ -293,7 +293,7 @@ impl<'c> Messages<'c> { message.sender as "sender: login::Id", message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", - message.body, + message.body as "body: Body", deleted.deleted_at as "deleted_at: DateTime", deleted.deleted_sequence as "deleted_sequence: Sequence" from message diff --git a/src/message/snapshot.rs b/src/message/snapshot.rs index 7300918..53b7176 100644 --- a/src/message/snapshot.rs +++ b/src/message/snapshot.rs @@ -1,6 +1,6 @@ use super::{ event::{Event, Sent}, - Id, + Body, Id, }; use crate::{channel, clock::DateTime, event::Instant, login}; @@ -11,7 +11,7 @@ pub struct Message { pub channel: channel::Id, pub sender: login::Id, pub id: Id, - pub body: String, + pub body: Body, #[serde(skip_serializing_if = "Option::is_none")] pub deleted_at: Option<DateTime>, } diff --git a/src/name.rs b/src/name.rs new file mode 100644 index 0000000..9187d33 --- /dev/null +++ b/src/name.rs @@ -0,0 +1,85 @@ +use std::fmt; + +use crate::normalize::{ident, nfc}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize, sqlx::Type)] +#[serde(from = "String", into = "String")] +pub struct Name { + display: nfc::String, + canonical: ident::String, +} + +impl Name { + pub fn new<D, C>(display: D, canonical: C) -> Result<Self, Error> + where + D: AsRef<str>, + C: AsRef<str>, + { + let name = Self::from(display); + + if name.canonical.as_str() == canonical.as_ref() { + Ok(name) + } else { + Err(Error::CanonicalMismatch( + canonical.as_ref().into(), + name.canonical, + name.display, + )) + } + } + + pub fn optional<D, C>(display: Option<D>, canonical: Option<C>) -> Result<Option<Self>, Error> + where + D: AsRef<str>, + C: AsRef<str>, + { + display + .zip(canonical) + .map(|(display, canonical)| Self::new(display, canonical)) + .transpose() + } + + pub fn display(&self) -> &nfc::String { + &self.display + } + + pub fn canonical(&self) -> &ident::String { + &self.canonical + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("stored canonical form {0:#?} does not match computed canonical form {:#?} for name {:#?}", .1.as_str(), .2.as_str())] + CanonicalMismatch(String, ident::String, nfc::String), +} + +impl Default for Name { + fn default() -> Self { + Self::from(String::default()) + } +} + +impl fmt::Display for Name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.display.fmt(f) + } +} + +impl<S> From<S> for Name +where + S: AsRef<str>, +{ + fn from(name: S) -> Self { + let display = nfc::String::from(&name); + let canonical = ident::String::from(&name); + + Self { display, canonical } + } +} + +impl From<Name> for String { + fn from(name: Name) -> Self { + name.display.into() + } +} diff --git a/src/normalize/mod.rs b/src/normalize/mod.rs new file mode 100644 index 0000000..6294201 --- /dev/null +++ b/src/normalize/mod.rs @@ -0,0 +1,36 @@ +mod string; + +pub mod nfc { + use std::string::String as StdString; + + use unicode_normalization::UnicodeNormalization as _; + + pub type String = super::string::String<Nfc>; + + #[derive(Clone, Debug, Default, Eq, PartialEq)] + pub struct Nfc; + + impl super::string::Normalize for Nfc { + fn normalize(&self, value: &str) -> StdString { + value.nfc().collect() + } + } +} + +pub mod ident { + use std::string::String as StdString; + + use unicode_casefold::UnicodeCaseFold as _; + use unicode_normalization::UnicodeNormalization as _; + + pub type String = super::string::String<Ident>; + + #[derive(Clone, Debug, Default, Eq, PartialEq)] + pub struct Ident; + + impl super::string::Normalize for Ident { + fn normalize(&self, value: &str) -> StdString { + value.case_fold().nfkc().collect() + } + } +} diff --git a/src/normalize/string.rs b/src/normalize/string.rs new file mode 100644 index 0000000..a0d178c --- /dev/null +++ b/src/normalize/string.rs @@ -0,0 +1,112 @@ +use std::{fmt, string::String as StdString}; + +use sqlx::{ + encode::{Encode, IsNull}, + Database, Decode, Type, +}; + +pub trait Normalize: Clone + Default { + fn normalize(&self, value: &str) -> StdString; +} + +#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(into = "StdString", from = "StdString")] +#[serde(bound = "N: Normalize")] +pub struct String<N>(StdString, N); + +impl<N> fmt::Display for String<N> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value, _) = self; + value.fmt(f) + } +} + +impl<S, N> From<S> for String<N> +where + S: AsRef<str>, + N: Normalize, +{ + fn from(value: S) -> Self { + let normalizer = N::default(); + let value = normalizer.normalize(value.as_ref()); + + Self(value, normalizer) + } +} + +impl<N> From<String<N>> for StdString { + fn from(value: String<N>) -> Self { + let String(value, _) = value; + value + } +} + +impl<N> std::ops::Deref for String<N> { + type Target = StdString; + + fn deref(&self) -> &Self::Target { + let Self(value, _) = self; + value + } +} + +// Type is manually implemented so that we can implement Decode to do +// normalization on read. Implementation is otherwise based on +// `#[derive(sqlx::Type)]` with the `#[sqlx(transparent)]` attribute. +impl<DB, N> Type<DB> for String<N> +where + DB: Database, + StdString: Type<DB>, +{ + fn type_info() -> <DB as Database>::TypeInfo { + <StdString as Type<DB>>::type_info() + } + + fn compatible(ty: &<DB as Database>::TypeInfo) -> bool { + <StdString as Type<DB>>::compatible(ty) + } +} + +impl<'r, DB, N> Decode<'r, DB> for String<N> +where + DB: Database, + StdString: Decode<'r, DB>, + N: Normalize, +{ + fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> { + let value = StdString::decode(value)?; + Ok(Self::from(value)) + } +} + +impl<'q, DB, N> Encode<'q, DB> for String<N> +where + DB: Database, + StdString: Encode<'q, DB>, +{ + fn encode_by_ref( + &self, + buf: &mut <DB as Database>::ArgumentBuffer<'q>, + ) -> Result<IsNull, sqlx::error::BoxDynError> { + let Self(value, _) = self; + value.encode_by_ref(buf) + } + + fn encode( + self, + buf: &mut <DB as Database>::ArgumentBuffer<'q>, + ) -> Result<IsNull, sqlx::error::BoxDynError> { + let Self(value, _) = self; + value.encode(buf) + } + + fn produces(&self) -> Option<<DB as Database>::TypeInfo> { + let Self(value, _) = self; + value.produces() + } + + fn size_hint(&self) -> usize { + let Self(value, _) = self; + value.size_hint() + } +} diff --git a/src/setup/app.rs b/src/setup/app.rs index d015813..030b5f6 100644 --- a/src/setup/app.rs +++ b/src/setup/app.rs @@ -5,6 +5,7 @@ use crate::{ clock::DateTime, event::{repo::Provider as _, Broadcaster, Event}, login::{repo::Provider as _, Login, Password}, + name::Name, token::{repo::Provider as _, Secret}, }; @@ -20,7 +21,7 @@ impl<'a> Setup<'a> { pub async fn initial( &self, - name: &str, + name: &Name, password: &Password, created_at: &DateTime, ) -> Result<(Login, Secret), Error> { diff --git a/src/setup/routes/post.rs b/src/setup/routes/post.rs index 34f4ed2..fb2280a 100644 --- a/src/setup/routes/post.rs +++ b/src/setup/routes/post.rs @@ -9,6 +9,7 @@ use crate::{ clock::RequestedAt, error::Internal, login::{Login, Password}, + name::Name, setup::app, token::extract::IdentityToken, }; @@ -30,7 +31,7 @@ pub async fn handler( #[derive(serde::Deserialize)] pub struct Request { - pub name: String, + pub name: Name, pub password: Password, } diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs index a1dda61..3831c82 100644 --- a/src/test/fixtures/channel.rs +++ b/src/test/fixtures/channel.rs @@ -11,6 +11,7 @@ use crate::{ channel::{self, Channel}, clock::RequestedAt, event::Event, + name::Name, }; pub async fn create(app: &App, created_at: &RequestedAt) -> Channel { @@ -21,13 +22,13 @@ pub async fn create(app: &App, created_at: &RequestedAt) -> Channel { .expect("should always succeed if the channel is actually new") } -pub fn propose() -> String { - rand::random::<Name>().to_string() +pub fn propose() -> Name { + rand::random::<NameTemplate>().to_string().into() } -struct Name(String); +struct NameTemplate(String); faker_impl_from_templates! { - Name; "{} {}", CityName, FullName; + NameTemplate; "{} {}", CityName, FullName; } pub fn events(event: Event) -> future::Ready<Option<channel::Event>> { diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index b6766fe..714b936 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -5,6 +5,7 @@ use crate::{ app::App, clock::RequestedAt, login::{self, Login, Password}, + name::Name, }; pub async fn create_with_password(app: &App, created_at: &RequestedAt) -> (Login, Password) { @@ -29,16 +30,16 @@ pub async fn create(app: &App, created_at: &RequestedAt) -> Login { pub fn fictitious() -> Login { Login { id: login::Id::generate(), - name: name(), + name: propose_name(), } } -pub fn propose() -> (String, Password) { - (name(), propose_password()) +pub fn propose() -> (Name, Password) { + (propose_name(), propose_password()) } -fn name() -> String { - rand::random::<internet::Username>().to_string() +fn propose_name() -> Name { + rand::random::<internet::Username>().to_string().into() } pub fn propose_password() -> Password { diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs index eb00e7c..c450bce 100644 --- a/src/test/fixtures/message.rs +++ b/src/test/fixtures/message.rs @@ -8,7 +8,7 @@ use crate::{ clock::RequestedAt, event::Event, login::Login, - message::{self, Message}, + message::{self, Body, Message}, }; pub async fn send(app: &App, channel: &Channel, login: &Login, sent_at: &RequestedAt) -> Message { @@ -20,8 +20,8 @@ pub async fn send(app: &App, channel: &Channel, login: &Login, sent_at: &Request .expect("should succeed if the channel exists") } -pub fn propose() -> String { - rand::random::<Paragraphs>().to_string() +pub fn propose() -> Body { + rand::random::<Paragraphs>().to_string().into() } pub fn events(event: Event) -> future::Ready<Option<message::Event>> { diff --git a/src/token/app.rs b/src/token/app.rs index 0dc1a46..c19d6a0 100644 --- a/src/token/app.rs +++ b/src/token/app.rs @@ -7,12 +7,14 @@ use futures::{ use sqlx::sqlite::SqlitePool; use super::{ - repo::auth::Provider as _, repo::Provider as _, Broadcaster, Event as TokenEvent, Id, Secret, + repo::{self, auth::Provider as _, Provider as _}, + Broadcaster, Event as TokenEvent, Id, Secret, }; use crate::{ clock::DateTime, db::NotFound as _, login::{Login, Password}, + name::{self, Name}, }; pub struct Tokens<'a> { @@ -27,7 +29,7 @@ impl<'a> Tokens<'a> { pub async fn login( &self, - name: &str, + name: &Name, password: &Password, login_at: &DateTime, ) -> Result<(Login, Secret), LoginError> { @@ -65,14 +67,16 @@ impl<'a> Tokens<'a> { used_at: &DateTime, ) -> Result<(Id, Login), ValidateError> { let mut tx = self.db.begin().await?; - let login = tx + let (token, login) = tx .tokens() .validate(secret, used_at) .await .not_found(|| ValidateError::InvalidToken)?; tx.commit().await?; - Ok(login) + let login = login.as_snapshot().ok_or(ValidateError::LoginDeleted)?; + + Ok((token, login)) } pub async fn limit_stream<E>( @@ -162,15 +166,40 @@ pub enum LoginError { #[error(transparent)] Database(#[from] sqlx::Error), #[error(transparent)] + Name(#[from] name::Error), + #[error(transparent)] PasswordHash(#[from] password_hash::Error), } +impl From<repo::auth::LoadError> for LoginError { + fn from(error: repo::auth::LoadError) -> Self { + use repo::auth::LoadError; + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } +} + #[derive(Debug, thiserror::Error)] pub enum ValidateError { #[error("invalid token")] InvalidToken, + #[error("login deleted")] + LoginDeleted, #[error(transparent)] Database(#[from] sqlx::Error), + #[error(transparent)] + Name(#[from] name::Error), +} + +impl From<repo::LoadError> for ValidateError { + fn from(error: repo::LoadError) -> Self { + match error { + repo::LoadError::Database(error) => error.into(), + repo::LoadError::Name(error) => error.into(), + } + } } #[derive(Debug)] diff --git a/src/token/repo/auth.rs b/src/token/repo/auth.rs index 88d0878..bdc4c33 100644 --- a/src/token/repo/auth.rs +++ b/src/token/repo/auth.rs @@ -2,8 +2,10 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, + db::NotFound, event::{Instant, Sequence}, login::{self, password::StoredHash, History, Login}, + name::{self, Name}, }; pub trait Provider { @@ -19,35 +21,53 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Auth<'t>(&'t mut SqliteConnection); impl<'t> Auth<'t> { - pub async fn for_name(&mut self, name: &str) -> Result<(History, StoredHash), sqlx::Error> { - let found = sqlx::query!( + pub async fn for_name(&mut self, name: &Name) -> Result<(History, StoredHash), LoadError> { + let name = name.canonical(); + let row = sqlx::query!( r#" - select - id as "id: login::Id", - name, - password_hash as "password_hash: StoredHash", + select + id as "id: login::Id", + display_name as "display_name: String", + canonical_name as "canonical_name: String", created_sequence as "created_sequence: Sequence", - created_at as "created_at: DateTime" - from login - where name = $1 - "#, + created_at as "created_at: DateTime", + password_hash as "password_hash: StoredHash" + from login + where canonical_name = $1 + "#, name, ) - .map(|row| { - ( - History { - login: Login { - id: row.id, - name: row.name, - }, - created: Instant::new(row.created_at, row.created_sequence), - }, - row.password_hash, - ) - }) .fetch_one(&mut *self.0) .await?; - Ok(found) + let login = History { + login: Login { + id: row.id, + name: Name::new(row.display_name, row.canonical_name)?, + }, + created: Instant::new(row.created_at, row.created_sequence), + }; + + Ok((login, row.password_hash)) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum LoadError { + Database(#[from] sqlx::Error), + Name(#[from] name::Error), +} + +impl<T> NotFound for Result<T, LoadError> { + type Ok = T; + type Error = LoadError; + + fn optional(self) -> Result<Option<T>, LoadError> { + match self { + Ok(value) => Ok(Some(value)), + Err(LoadError::Database(sqlx::Error::RowNotFound)) => Ok(None), + Err(other) => Err(other), + } } } diff --git a/src/token/repo/mod.rs b/src/token/repo/mod.rs index 9169743..d8463eb 100644 --- a/src/token/repo/mod.rs +++ b/src/token/repo/mod.rs @@ -1,4 +1,4 @@ pub mod auth; mod token; -pub use self::token::Provider; +pub use self::token::{LoadError, Provider}; diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs index c592dcd..35ea385 100644 --- a/src/token/repo/token.rs +++ b/src/token/repo/token.rs @@ -3,7 +3,10 @@ use uuid::Uuid; use crate::{ clock::DateTime, + db::NotFound, + event::{Instant, Sequence}, login::{self, History, Login}, + name::{self, Name}, token::{Id, Secret}, }; @@ -100,53 +103,78 @@ impl<'c> Tokens<'c> { } // Validate a token by its secret, retrieving the associated Login record. - // Will return [None] if the token is not valid. The token's last-used - // timestamp will be set to `used_at`. + // Will return an error if the token is not valid. If successful, the + // retrieved token's last-used timestamp will be set to `used_at`. pub async fn validate( &mut self, secret: &Secret, used_at: &DateTime, - ) -> Result<(Id, Login), sqlx::Error> { + ) -> Result<(Id, History), LoadError> { // I would use `update … returning` to do this in one query, but // sqlite3, as of this writing, does not allow an update's `returning` // clause to reference columns from tables joined into the update. Two // queries is fine, but it feels untidy. - sqlx::query!( + let (token, login) = sqlx::query!( r#" update token set last_used_at = $1 where secret = $2 + returning + id as "token: Id", + login as "login: login::Id" "#, used_at, secret, ) - .execute(&mut *self.0) + .map(|row| (row.token, row.login)) + .fetch_one(&mut *self.0) .await?; let login = sqlx::query!( r#" select - token.id as "token_id: Id", - login.id as "login_id: login::Id", - login.name as "login_name" + id as "id: login::Id", + display_name as "display_name: String", + canonical_name as "canonical_name: String", + created_sequence as "created_sequence: Sequence", + created_at as "created_at: DateTime" from login - join token on login.id = token.login - where token.secret = $1 + where id = $1 "#, - secret, + login, ) .map(|row| { - ( - row.token_id, - Login { - id: row.login_id, - name: row.login_name, + Ok::<_, name::Error>(History { + login: Login { + id: row.id, + name: Name::new(row.display_name, row.canonical_name)?, }, - ) + created: Instant::new(row.created_at, row.created_sequence), + }) }) .fetch_one(&mut *self.0) - .await?; + .await??; + + Ok((token, login)) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum LoadError { + Database(#[from] sqlx::Error), + Name(#[from] name::Error), +} + +impl<T> NotFound for Result<T, LoadError> { + type Ok = T; + type Error = LoadError; - Ok(login) + fn optional(self) -> Result<Option<T>, LoadError> { + match self { + Ok(value) => Ok(Some(value)), + Err(LoadError::Database(sqlx::Error::RowNotFound)) => Ok(None), + Err(other) => Err(other), + } } } |
