diff options
Diffstat (limited to 'src')
59 files changed, 1065 insertions, 411 deletions
@@ -5,12 +5,14 @@ use crate::{ channel::app::Channels, event::{self, app::Events}, invite::app::Invites, - login::app::Logins, message::app::Messages, setup::app::Setup, token::{self, app::Tokens}, }; +#[cfg(test)] +use crate::login::app::Logins; + #[derive(Clone)] pub struct App { db: SqlitePool, @@ -47,11 +49,6 @@ impl App { Invites::new(&self.db, &self.events) } - #[cfg(not(test))] - pub const fn logins(&self) -> Logins { - Logins::new(&self.db) - } - #[cfg(test)] pub const fn logins(&self) -> Logins { Logins::new(&self.db, &self.events) diff --git a/src/bin/hi-recanonicalize.rs b/src/bin/hi-recanonicalize.rs deleted file mode 100644 index 4081276..0000000 --- a/src/bin/hi-recanonicalize.rs +++ /dev/null @@ -1,9 +0,0 @@ -use clap::Parser; - -use hi::cli; - -#[tokio::main] -async fn main() -> Result<(), cli::recanonicalize::Error> { - let args = cli::recanonicalize::Args::parse(); - args.run().await -} diff --git a/src/boot/app.rs b/src/boot/app.rs index e716b58..909f7d8 100644 --- a/src/boot/app.rs +++ b/src/boot/app.rs @@ -22,9 +22,9 @@ impl<'a> Boot<'a> { let mut tx = self.db.begin().await?; let resume_point = tx.sequence().current().await?; - let logins = tx.logins().all(resume_point.into()).await?; + let logins = tx.logins().all(resume_point).await?; let channels = tx.channels().all(resume_point).await?; - let messages = tx.messages().all(resume_point.into()).await?; + let messages = tx.messages().all(resume_point).await?; tx.commit().await?; diff --git a/src/boot/routes/test.rs b/src/boot/routes/test.rs index 8808b70..202dcb9 100644 --- a/src/boot/routes/test.rs +++ b/src/boot/routes/test.rs @@ -85,7 +85,7 @@ async fn excludes_deleted_messages() { let deleted_message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; app.messages() - .delete(&deleted_message.id, &fixtures::now()) + .delete(&sender, &deleted_message.id, &fixtures::now()) .await .expect("deleting valid message succeeds"); diff --git a/src/channel/app.rs b/src/channel/app.rs index 8359277..21784e9 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -4,13 +4,13 @@ use sqlx::sqlite::SqlitePool; use super::{ repo::{LoadError, Provider as _}, - Channel, Id, + validate, Channel, Id, }; use crate::{ clock::DateTime, db::{Duplicate as _, NotFound as _}, event::{repo::Provider as _, Broadcaster, Event, Sequence}, - message::repo::Provider as _, + message::{self, repo::Provider as _}, name::{self, Name}, }; @@ -25,6 +25,10 @@ impl<'a> Channels<'a> { } pub async fn create(&self, name: &Name, created_at: &DateTime) -> Result<Channel, CreateError> { + if !validate::name(name) { + return Err(CreateError::InvalidName(name.clone())); + } + let mut tx = self.db.begin().await?; let created = tx.sequence().next(created_at).await?; let channel = tx @@ -44,38 +48,36 @@ impl<'a> Channels<'a> { // it exists in the specific moment when you call it. pub async fn get(&self, channel: &Id) -> Result<Channel, Error> { let not_found = || Error::NotFound(channel.clone()); + let deleted = || Error::Deleted(channel.clone()); let mut tx = self.db.begin().await?; let channel = tx.channels().by_id(channel).await.not_found(not_found)?; tx.commit().await?; - channel.as_snapshot().ok_or_else(not_found) + channel.as_snapshot().ok_or_else(deleted) } - pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), Error> { + pub async fn delete(&self, channel: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> { let mut tx = self.db.begin().await?; let channel = tx .channels() .by_id(channel) .await - .not_found(|| Error::NotFound(channel.clone()))?; + .not_found(|| DeleteError::NotFound(channel.clone()))?; channel .as_snapshot() - .ok_or_else(|| Error::Deleted(channel.id().clone()))?; + .ok_or_else(|| DeleteError::Deleted(channel.id().clone()))?; let mut events = Vec::new(); let messages = tx.messages().live(&channel).await?; - for message in messages { - let deleted = tx.sequence().next(deleted_at).await?; - let message = tx.messages().delete(&message, &deleted).await?; - events.extend( - message - .events() - .filter(Sequence::start_from(deleted.sequence)) - .map(Event::from), - ); + let has_messages = messages + .iter() + .map(message::History::as_snapshot) + .any(|message| message.is_some()); + if has_messages { + return Err(DeleteError::NotEmpty(channel.id().clone())); } let deleted = tx.sequence().next(deleted_at).await?; @@ -135,20 +137,14 @@ impl<'a> Channels<'a> { Ok(()) } - - pub async fn recanonicalize(&self) -> Result<(), sqlx::Error> { - let mut tx = self.db.begin().await?; - tx.channels().recanonicalize().await?; - tx.commit().await?; - - Ok(()) - } } #[derive(Debug, thiserror::Error)] pub enum CreateError { #[error("channel named {0} already exists")] DuplicateName(Name), + #[error("invalid channel name: {0}")] + InvalidName(Name), #[error(transparent)] Database(#[from] sqlx::Error), #[error(transparent)] @@ -186,6 +182,29 @@ impl From<LoadError> for Error { } #[derive(Debug, thiserror::Error)] +pub enum DeleteError { + #[error("channel {0} not found")] + NotFound(Id), + #[error("channel {0} deleted")] + Deleted(Id), + #[error("channel {0} not empty")] + NotEmpty(Id), + #[error(transparent)] + Database(#[from] sqlx::Error), + #[error(transparent)] + Name(#[from] name::Error), +} + +impl From<LoadError> for DeleteError { + fn from(error: LoadError) -> Self { + match error { + LoadError::Database(error) => error.into(), + LoadError::Name(error) => error.into(), + } + } +} + +#[derive(Debug, thiserror::Error)] pub enum ExpireError { #[error(transparent)] Database(#[from] sqlx::Error), diff --git a/src/channel/history.rs b/src/channel/history.rs index 4b9fcc7..ef2120d 100644 --- a/src/channel/history.rs +++ b/src/channel/history.rs @@ -1,8 +1,10 @@ +use itertools::Itertools as _; + use super::{ event::{Created, Deleted, Event}, Channel, Id, }; -use crate::event::{Instant, ResumePoint, Sequence}; +use crate::event::{Instant, Sequence}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct History { @@ -26,9 +28,9 @@ impl History { self.channel.clone() } - pub fn as_of(&self, resume_point: impl Into<ResumePoint>) -> Option<Channel> { + pub fn as_of(&self, resume_point: Sequence) -> Option<Channel> { self.events() - .filter(Sequence::up_to(resume_point.into())) + .filter(Sequence::up_to(resume_point)) .collect() } @@ -41,7 +43,9 @@ impl History { // Event factories impl History { pub fn events(&self) -> impl Iterator<Item = Event> { - [self.created()].into_iter().chain(self.deleted()) + [self.created()] + .into_iter() + .merge_by(self.deleted(), Sequence::merge) } fn created(&self) -> Event { diff --git a/src/channel/mod.rs b/src/channel/mod.rs index eb8200b..d5ba828 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -5,5 +5,6 @@ mod id; pub mod repo; mod routes; mod snapshot; +mod validate; pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Channel}; diff --git a/src/channel/repo.rs b/src/channel/repo.rs index a49db52..6612151 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -5,7 +5,7 @@ use crate::{ channel::{Channel, History, Id}, clock::DateTime, db::NotFound, - event::{Instant, ResumePoint, Sequence}, + event::{Instant, Sequence}, name::{self, Name}, }; @@ -32,12 +32,13 @@ impl<'c> Channels<'c> { sqlx::query!( r#" insert - into channel (id, created_at, created_sequence) - values ($1, $2, $3) + into channel (id, created_at, created_sequence, last_sequence) + values ($1, $2, $3, $4) "#, id, created.at, created.sequence, + created.sequence, ) .execute(&mut *self.0) .await?; @@ -144,13 +145,13 @@ impl<'c> Channels<'c> { Ok(channels) } - pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> { + pub async fn replay(&mut self, resume_at: Sequence) -> Result<Vec<History>, LoadError> { let channels = sqlx::query!( r#" select id as "id: Id", - name.display_name as "display_name: String", - name.canonical_name as "canonical_name: String", + name.display_name as "display_name?: String", + name.canonical_name as "canonical_name?: String", channel.created_at as "created_at: DateTime", channel.created_sequence as "created_sequence: Sequence", deleted.deleted_at as "deleted_at?: DateTime", @@ -160,7 +161,7 @@ impl<'c> Channels<'c> { using (id) left join channel_deleted as deleted using (id) - where coalesce(channel.created_sequence > $1, true) + where channel.last_sequence > $1 "#, resume_at, ) @@ -191,6 +192,19 @@ impl<'c> Channels<'c> { let id = channel.id(); sqlx::query!( r#" + update channel + set last_sequence = max(last_sequence, $1) + where id = $2 + returning id as "id: Id" + "#, + deleted.sequence, + id, + ) + .fetch_one(&mut *self.0) + .await?; + + sqlx::query!( + r#" insert into channel_deleted (id, deleted_at, deleted_sequence) values ($1, $2, $3) "#, @@ -300,38 +314,6 @@ impl<'c> Channels<'c> { Ok(channels) } - - pub async fn recanonicalize(&mut self) -> Result<(), sqlx::Error> { - let channels = sqlx::query!( - r#" - select - id as "id: Id", - display_name as "display_name: String" - from channel_name - "#, - ) - .fetch_all(&mut *self.0) - .await?; - - for channel in channels { - let name = Name::from(channel.display_name); - let canonical_name = name.canonical(); - - sqlx::query!( - r#" - update channel_name - set canonical_name = $1 - where id = $2 - "#, - canonical_name, - channel.id, - ) - .execute(&mut *self.0) - .await?; - } - - Ok(()) - } } #[derive(Debug, thiserror::Error)] diff --git a/src/channel/routes/channel/delete.rs b/src/channel/routes/channel/delete.rs index 2d2b5f1..9c093c1 100644 --- a/src/channel/routes/channel/delete.rs +++ b/src/channel/routes/channel/delete.rs @@ -36,14 +36,19 @@ impl IntoResponse for Response { #[derive(Debug, thiserror::Error)] #[error(transparent)] -pub struct Error(#[from] pub app::Error); +pub struct Error(#[from] pub app::DeleteError); impl IntoResponse for Error { fn into_response(self) -> response::Response { let Self(error) = self; #[allow(clippy::match_wildcard_for_single_variants)] match error { - app::Error::NotFound(_) | app::Error::Deleted(_) => NotFound(error).into_response(), + app::DeleteError::NotFound(_) | app::DeleteError::Deleted(_) => { + NotFound(error).into_response() + } + app::DeleteError::NotEmpty(_) => { + (StatusCode::CONFLICT, error.to_string()).into_response() + } other => Internal::from(other).into_response(), } } diff --git a/src/channel/routes/channel/test/delete.rs b/src/channel/routes/channel/test/delete.rs index 0371b0a..77a0b03 100644 --- a/src/channel/routes/channel/test/delete.rs +++ b/src/channel/routes/channel/test/delete.rs @@ -55,7 +55,7 @@ pub async fn invalid_channel_id() { // Verify the response - assert!(matches!(error, app::Error::NotFound(id) if id == channel)); + assert!(matches!(error, app::DeleteError::NotFound(id) if id == channel)); } #[tokio::test] @@ -84,7 +84,7 @@ pub async fn channel_deleted() { // Verify the response - assert!(matches!(error, app::Error::Deleted(id) if id == channel.id)); + assert!(matches!(error, app::DeleteError::Deleted(id) if id == channel.id)); } #[tokio::test] @@ -113,7 +113,7 @@ pub async fn channel_expired() { // Verify the response - assert!(matches!(error, app::Error::Deleted(id) if id == channel.id)); + assert!(matches!(error, app::DeleteError::Deleted(id) if id == channel.id)); } #[tokio::test] @@ -147,5 +147,31 @@ pub async fn channel_purged() { // Verify the response - assert!(matches!(error, app::Error::NotFound(id) if id == channel.id)); + assert!(matches!(error, app::DeleteError::NotFound(id) if id == channel.id)); +} + +#[tokio::test] +pub async fn channel_not_empty() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app, &fixtures::now()).await; + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; + + // Send the request + + let deleter = fixtures::identity::create(&app, &fixtures::now()).await; + let delete::Error(error) = delete::handler( + State(app.clone()), + Path(channel.id.clone()), + fixtures::now(), + deleter, + ) + .await + .expect_err("deleting a channel with messages fails"); + + // Verify the response + + assert!(matches!(error, app::DeleteError::NotEmpty(id) if id == channel.id)); } diff --git a/src/channel/routes/channel/test/post.rs b/src/channel/routes/channel/test/post.rs index 111a703..bc0684b 100644 --- a/src/channel/routes/channel/test/post.rs +++ b/src/channel/routes/channel/test/post.rs @@ -15,6 +15,7 @@ async fn messages_in_order() { let app = fixtures::scratch_app().await; let sender = fixtures::identity::create(&app, &fixtures::now()).await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Call the endpoint (twice) @@ -41,7 +42,7 @@ async fn messages_in_order() { let mut events = app .events() - .subscribe(None) + .subscribe(resume_point) .await .expect("subscribing to a valid channel succeeds") .filter_map(fixtures::event::message) diff --git a/src/channel/routes/post.rs b/src/channel/routes/post.rs index 810445c..2cf1cc0 100644 --- a/src/channel/routes/post.rs +++ b/src/channel/routes/post.rs @@ -54,6 +54,9 @@ impl IntoResponse for Error { app::CreateError::DuplicateName(_) => { (StatusCode::CONFLICT, error.to_string()).into_response() } + app::CreateError::InvalidName(_) => { + (StatusCode::BAD_REQUEST, error.to_string()).into_response() + } other => Internal::from(other).into_response(), } } diff --git a/src/channel/routes/test.rs b/src/channel/routes/test.rs index 10b1e8d..cba8f2e 100644 --- a/src/channel/routes/test.rs +++ b/src/channel/routes/test.rs @@ -16,6 +16,7 @@ async fn new_channel() { let app = fixtures::scratch_app().await; let creator = fixtures::identity::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Call the endpoint @@ -44,7 +45,7 @@ async fn new_channel() { let mut events = app .events() - .subscribe(None) + .subscribe(resume_point) .await .expect("subscribing never fails") .filter_map(fixtures::event::channel) @@ -116,6 +117,30 @@ async fn conflicting_canonical_name() { } #[tokio::test] +async fn invalid_name() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let creator = fixtures::identity::create(&app, &fixtures::now()).await; + + // Call the endpoint + + let name = fixtures::channel::propose_invalid_name(); + let request = post::Request { name: name.clone() }; + let post::Error(error) = + post::handler(State(app.clone()), creator, fixtures::now(), Json(request)) + .await + .expect_err("invalid channel name should fail the request"); + + // Verify the structure of the response + + assert!(matches!( + error, + app::CreateError::InvalidName(error_name) if name == error_name + )); +} + +#[tokio::test] async fn name_reusable_after_delete() { // Set up the environment diff --git a/src/channel/validate.rs b/src/channel/validate.rs new file mode 100644 index 0000000..0c97293 --- /dev/null +++ b/src/channel/validate.rs @@ -0,0 +1,23 @@ +use unicode_segmentation::UnicodeSegmentation as _; + +use crate::name::Name; + +// Picked out of a hat. The power of two is not meaningful. +const NAME_TOO_LONG: usize = 64; + +pub fn name(name: &Name) -> bool { + let display = name.display(); + + [ + display.graphemes(true).count() < NAME_TOO_LONG, + display.chars().all(|ch| !ch.is_control()), + display.chars().next().is_some_and(|c| !c.is_whitespace()), + display.chars().last().is_some_and(|c| !c.is_whitespace()), + display + .chars() + .zip(display.chars().skip(1)) + .all(|(a, b)| !(a.is_whitespace() && b.is_whitespace())), + ] + .into_iter() + .all(|value| value) +} diff --git a/src/cli/mod.rs b/src/cli.rs index c75ce2b..0659851 100644 --- a/src/cli/mod.rs +++ b/src/cli.rs @@ -22,8 +22,6 @@ use crate::{ ui, }; -pub mod recanonicalize; - /// Command-line entry point for running the `hi` server. /// /// This is intended to be used as a Clap [Parser], to capture command-line diff --git a/src/cli/recanonicalize.rs b/src/cli/recanonicalize.rs deleted file mode 100644 index 9db5b77..0000000 --- a/src/cli/recanonicalize.rs +++ /dev/null @@ -1,86 +0,0 @@ -use sqlx::sqlite::SqlitePool; - -use crate::{app::App, db}; - -/// Command-line entry point for repairing canonical names in the `hi` database. -/// This command may be necessary after an upgrade, if the canonical forms of -/// names has changed. It will re-calculate the canonical form of each name in -/// the database, based on its display form, and store the results back to the -/// database. -/// -/// This is intended to be used as a Clap [Parser], to capture command-line -/// arguments for the `hi-recanonicalize` command: -/// -/// ```no_run -/// # use hi::cli::recanonicalize::Error; -/// # -/// # #[tokio::main] -/// # async fn main() -> Result<(), Error> { -/// use clap::Parser; -/// use hi::cli::recanonicalize::Args; -/// -/// let args = Args::parse(); -/// args.run().await?; -/// # Ok(()) -/// # } -/// ``` -#[derive(clap::Parser)] -#[command( - version, - about = "Recanonicalize names in the `hi` database.", - long_about = r#"Recanonicalize names in the `hi` database. - -The `hi` server must not be running while this command is run. - -The database at `--database-url` will also be created, or upgraded, automatically."# -)] -pub struct Args { - /// Sqlite URL or path for the `hi` database - #[arg(short, long, env, default_value = "sqlite://.hi")] - database_url: String, - - /// Sqlite URL or path for a backup of the `hi` database during upgrades - #[arg(short = 'D', long, env, default_value = "sqlite://.hi.backup")] - backup_database_url: String, -} - -impl Args { - /// Recanonicalizes the `hi` database, using the parsed configuation in - /// `self`. - /// - /// This will perform the following tasks: - /// - /// * Migrate the `hi` database (at `--database-url`). - /// * Recanonicalize names in the `login` and `channel` tables. - /// - /// # Errors - /// - /// Will return `Err` if the canonicalization or database upgrade processes - /// fail. The specific [`Error`] variant will expose the cause - /// of the failure. - pub async fn run(self) -> Result<(), Error> { - let pool = self.pool().await?; - - let app = App::from(pool); - app.logins().recanonicalize().await?; - app.channels().recanonicalize().await?; - - Ok(()) - } - - async fn pool(&self) -> Result<SqlitePool, db::Error> { - db::prepare(&self.database_url, &self.backup_database_url).await - } -} - -/// Errors that can be raised by [`Args::run`]. -#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub enum Error { - // /// Failure due to `io::Error`. See [`io::Error`]. - // Io(#[from] io::Error), - /// Failure due to a database initialization error. See [`db::Error`]. - Database(#[from] db::Error), - /// Failure due to a data manipulation error. See [`sqlx::Error`]. - Sqlx(#[from] sqlx::Error), -} diff --git a/src/event/app.rs b/src/event/app.rs index c754388..b309245 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -6,7 +6,7 @@ use futures::{ use itertools::Itertools as _; use sqlx::sqlite::SqlitePool; -use super::{broadcaster::Broadcaster, Event, ResumePoint, Sequence, Sequenced}; +use super::{broadcaster::Broadcaster, Event, Sequence, Sequenced}; use crate::{ channel::{self, repo::Provider as _}, login::{self, repo::Provider as _}, @@ -26,9 +26,8 @@ impl<'a> Events<'a> { pub async fn subscribe( &self, - resume_at: impl Into<ResumePoint>, + resume_at: Sequence, ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, Error> { - let resume_at = resume_at.into(); // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. let live_messages = self.events.subscribe(); @@ -63,7 +62,7 @@ impl<'a> Events<'a> { .merge_by(channel_events, Sequence::merge) .merge_by(message_events, Sequence::merge) .collect::<Vec<_>>(); - let resume_live_at = replay_events.last().map(Sequenced::sequence); + let resume_live_at = replay_events.last().map_or(resume_at, Sequenced::sequence); let replay = stream::iter(replay_events); @@ -77,7 +76,7 @@ impl<'a> Events<'a> { Ok(replay.chain(live_messages)) } - fn resume(resume_at: ResumePoint) -> impl for<'m> FnMut(&'m Event) -> future::Ready<bool> { + fn resume(resume_at: Sequence) -> impl for<'m> FnMut(&'m Event) -> future::Ready<bool> { let filter = Sequence::after(resume_at); move |event| future::ready(filter(event)) } diff --git a/src/event/mod.rs b/src/event/mod.rs index 69c7a10..9996916 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -13,8 +13,6 @@ pub use self::{ sequence::{Instant, Sequence, Sequenced}, }; -pub type ResumePoint = Option<Sequence>; - #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum Event { diff --git a/src/event/routes/get.rs b/src/event/routes/get.rs index 22e8762..ceebcc9 100644 --- a/src/event/routes/get.rs +++ b/src/event/routes/get.rs @@ -12,7 +12,7 @@ use futures::stream::{Stream, StreamExt as _}; use crate::{ app::App, error::{Internal, Unauthorized}, - event::{app, extract::LastEventId, Event, ResumePoint, Sequence, Sequenced as _}, + event::{app, extract::LastEventId, Event, Sequence, Sequenced as _}, token::{app::ValidateError, extract::Identity}, }; @@ -22,9 +22,7 @@ pub async fn handler( last_event_id: Option<LastEventId<Sequence>>, Query(query): Query<QueryParams>, ) -> Result<Response<impl Stream<Item = Event> + std::fmt::Debug>, Error> { - let resume_at = last_event_id - .map(LastEventId::into_inner) - .or(query.resume_point); + let resume_at = last_event_id.map_or(query.resume_point, LastEventId::into_inner); let stream = app.events().subscribe(resume_at).await?; let stream = app.tokens().limit_stream(identity.token, stream).await?; @@ -32,9 +30,9 @@ pub async fn handler( Ok(Response(stream)) } -#[derive(Default, serde::Deserialize)] +#[derive(serde::Deserialize)] pub struct QueryParams { - pub resume_point: ResumePoint, + pub resume_point: Sequence, } #[derive(Debug)] diff --git a/src/event/routes/test/channel.rs b/src/event/routes/test/channel.rs index 6a0a803..0695ab1 100644 --- a/src/event/routes/test/channel.rs +++ b/src/event/routes/test/channel.rs @@ -12,14 +12,19 @@ async fn creating() { // Set up the environment let app = fixtures::scratch_app().await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Create a channel @@ -46,6 +51,7 @@ async fn previously_created() { // Set up the environment let app = fixtures::scratch_app().await; + let resume_point = fixtures::boot::resume_point(&app).await; // Create a channel @@ -59,10 +65,14 @@ async fn previously_created() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Verify channel created event @@ -81,14 +91,19 @@ async fn expiring() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::ancient()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Expire channels @@ -113,6 +128,7 @@ async fn previously_expired() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::ancient()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Expire channels @@ -124,10 +140,14 @@ async fn previously_expired() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Check for expiry event let _ = events @@ -145,14 +165,19 @@ async fn deleting() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Delete the channel @@ -177,6 +202,7 @@ async fn previously_deleted() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Delete the channel @@ -188,10 +214,14 @@ async fn previously_deleted() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Check for expiry event let _ = events @@ -209,6 +239,7 @@ async fn previously_purged() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::ancient()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Delete and purge the channel @@ -225,10 +256,14 @@ async fn previously_purged() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Check for expiry event events diff --git a/src/event/routes/test/invite.rs b/src/event/routes/test/invite.rs index d24f474..73af62d 100644 --- a/src/event/routes/test/invite.rs +++ b/src/event/routes/test/invite.rs @@ -14,14 +14,19 @@ async fn accepting_invite() { let app = fixtures::scratch_app().await; let issuer = fixtures::login::create(&app, &fixtures::now()).await; let invite = fixtures::invite::issue(&app, &issuer, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Accept the invite @@ -50,6 +55,7 @@ async fn previously_accepted_invite() { let app = fixtures::scratch_app().await; let issuer = fixtures::login::create(&app, &fixtures::now()).await; let invite = fixtures::invite::issue(&app, &issuer, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Accept the invite @@ -63,10 +69,14 @@ async fn previously_accepted_invite() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Expect a login created event diff --git a/src/event/routes/test/message.rs b/src/event/routes/test/message.rs index 63a3f43..fafaeb3 100644 --- a/src/event/routes/test/message.rs +++ b/src/event/routes/test/message.rs @@ -16,14 +16,19 @@ async fn sending() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Call the endpoint let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Send a message @@ -56,6 +61,7 @@ async fn previously_sent() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Send a message @@ -74,10 +80,14 @@ async fn previously_sent() { // Call the endpoint let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Verify that an event is delivered @@ -96,6 +106,7 @@ async fn sent_in_multiple_channels() { let app = fixtures::scratch_app().await; let sender = fixtures::login::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; let channels = [ fixtures::channel::create(&app, &fixtures::now()).await, @@ -115,9 +126,14 @@ async fn sent_in_multiple_channels() { // Call the endpoint let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = get::handler(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Verify the structure of the response. @@ -141,6 +157,7 @@ async fn sent_sequentially() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; let messages = vec![ fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, @@ -151,9 +168,14 @@ async fn sent_sequentially() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = get::handler(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Verify the expected events in the expected order @@ -180,14 +202,19 @@ async fn expiring() { let channel = fixtures::channel::create(&app, &fixtures::ancient()).await; let sender = fixtures::login::create(&app, &fixtures::ancient()).await; let message = fixtures::message::send(&app, &channel, &sender, &fixtures::ancient()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Expire messages @@ -214,6 +241,7 @@ async fn previously_expired() { let channel = fixtures::channel::create(&app, &fixtures::ancient()).await; let sender = fixtures::login::create(&app, &fixtures::ancient()).await; let message = fixtures::message::send(&app, &channel, &sender, &fixtures::ancient()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Expire messages @@ -225,10 +253,14 @@ async fn previously_expired() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Check for expiry event let _ = events @@ -248,19 +280,24 @@ async fn deleting() { let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app, &fixtures::now()).await; let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Delete the message app.messages() - .delete(&message.id, &fixtures::now()) + .delete(&sender, &message.id, &fixtures::now()) .await .expect("deleting a valid message succeeds"); @@ -282,21 +319,26 @@ async fn previously_deleted() { let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app, &fixtures::now()).await; let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Delete the message app.messages() - .delete(&message.id, &fixtures::now()) + .delete(&sender, &message.id, &fixtures::now()) .await .expect("deleting a valid message succeeds"); // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Check for delete event let _ = events @@ -316,11 +358,12 @@ async fn previously_purged() { let channel = fixtures::channel::create(&app, &fixtures::ancient()).await; let sender = fixtures::login::create(&app, &fixtures::ancient()).await; let message = fixtures::message::send(&app, &channel, &sender, &fixtures::ancient()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Purge the message app.messages() - .delete(&message.id, &fixtures::ancient()) + .delete(&sender, &message.id, &fixtures::ancient()) .await .expect("deleting a valid message succeeds"); @@ -332,10 +375,14 @@ async fn previously_purged() { // Subscribe let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Check for delete event diff --git a/src/event/routes/test/resume.rs b/src/event/routes/test/resume.rs index 62b9bad..fabda0c 100644 --- a/src/event/routes/test/resume.rs +++ b/src/event/routes/test/resume.rs @@ -16,6 +16,7 @@ async fn resume() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; let initial_message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; @@ -34,7 +35,7 @@ async fn resume() { State(app.clone()), subscriber.clone(), None, - Query::default(), + Query(get::QueryParams { resume_point }), ) .await .expect("subscribe never fails"); @@ -55,7 +56,7 @@ async fn resume() { State(app), subscriber, Some(resume_at.into()), - Query::default(), + Query(get::QueryParams { resume_point }), ) .await .expect("subscribe never fails"); @@ -98,6 +99,7 @@ async fn serial_resume() { let sender = fixtures::login::create(&app, &fixtures::now()).await; let channel_a = fixtures::channel::create(&app, &fixtures::now()).await; let channel_b = fixtures::channel::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Call the endpoint @@ -115,7 +117,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), None, - Query::default(), + Query(get::QueryParams { resume_point }), ) .await .expect("subscribe never fails"); @@ -156,7 +158,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), - Query::default(), + Query(get::QueryParams { resume_point }), ) .await .expect("subscribe never fails"); @@ -197,7 +199,7 @@ async fn serial_resume() { State(app.clone()), subscriber.clone(), Some(resume_at.into()), - Query::default(), + Query(get::QueryParams { resume_point }), ) .await .expect("subscribe never fails"); diff --git a/src/event/routes/test/setup.rs b/src/event/routes/test/setup.rs index 007b03d..26b7ea7 100644 --- a/src/event/routes/test/setup.rs +++ b/src/event/routes/test/setup.rs @@ -15,6 +15,7 @@ async fn previously_completed() { // Set up the environment let app = fixtures::scratch_app().await; + let resume_point = fixtures::boot::resume_point(&app).await; // Complete initial setup @@ -28,10 +29,14 @@ async fn previously_completed() { // Subscribe to events let subscriber = fixtures::identity::create(&app, &fixtures::now()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Expect a login created event diff --git a/src/event/routes/test/token.rs b/src/event/routes/test/token.rs index 2039d9b..fa76865 100644 --- a/src/event/routes/test/token.rs +++ b/src/event/routes/test/token.rs @@ -14,6 +14,7 @@ async fn terminates_on_token_expiry() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe via the endpoint @@ -21,10 +22,14 @@ async fn terminates_on_token_expiry() { let subscriber = fixtures::identity::logged_in(&app, &subscriber_creds, &fixtures::ancient()).await; - let get::Response(events) = - get::handler(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); + let get::Response(events) = get::handler( + State(app.clone()), + subscriber, + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); // Verify the resulting stream's behaviour @@ -56,6 +61,7 @@ async fn terminates_on_logout() { let app = fixtures::scratch_app().await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; // Subscribe via the endpoint @@ -65,7 +71,7 @@ async fn terminates_on_logout() { State(app.clone()), subscriber.clone(), None, - Query::default(), + Query(get::QueryParams { resume_point }), ) .await .expect("subscribe never fails"); @@ -93,3 +99,53 @@ async fn terminates_on_logout() { .expect_none("end of stream") .await; } + +#[tokio::test] +async fn terminates_on_password_change() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app, &fixtures::now()).await; + let resume_point = fixtures::boot::resume_point(&app).await; + + // Subscribe via the endpoint + + let creds = fixtures::login::create_with_password(&app, &fixtures::now()).await; + let cookie = fixtures::cookie::logged_in(&app, &creds, &fixtures::now()).await; + let subscriber = fixtures::identity::from_cookie(&app, &cookie, &fixtures::now()).await; + + let get::Response(events) = get::handler( + State(app.clone()), + subscriber.clone(), + None, + Query(get::QueryParams { resume_point }), + ) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + let (_, password) = creds; + let to = fixtures::login::propose_password(); + app.tokens() + .change_password(&subscriber.login, &password, &to, &fixtures::now()) + .await + .expect("expiring tokens succeeds"); + + // These should not be delivered. + + let messages = [ + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await, + ]; + + events + .filter_map(fixtures::event::message) + .filter_map(fixtures::event::message::sent) + .filter(|event| future::ready(messages.iter().any(|message| &event.message == message))) + .next() + .expect_none("end of stream") + .await; +} diff --git a/src/event/sequence.rs b/src/event/sequence.rs index 9bc399b..77281c2 100644 --- a/src/event/sequence.rs +++ b/src/event/sequence.rs @@ -1,6 +1,5 @@ use std::fmt; -use super::ResumePoint; use crate::clock::DateTime; #[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)] @@ -51,18 +50,18 @@ impl fmt::Display for Sequence { } impl Sequence { - pub fn up_to<E>(resume_point: ResumePoint) -> impl for<'e> Fn(&'e E) -> bool + pub fn up_to<E>(resume_point: Sequence) -> impl for<'e> Fn(&'e E) -> bool where E: Sequenced, { - move |event| resume_point.map_or(true, |resume_point| event.sequence() <= resume_point) + move |event| event.sequence() <= resume_point } - pub fn after<E>(resume_point: ResumePoint) -> impl for<'e> Fn(&'e E) -> bool + pub fn after<E>(resume_point: Sequence) -> impl for<'e> Fn(&'e E) -> bool where E: Sequenced, { - move |event| resume_point < Some(event.sequence()) + move |event| resume_point < event.sequence() } pub fn start_from<E>(resume_point: Self) -> impl for<'e> Fn(&'e E) -> bool diff --git a/src/invite/app.rs b/src/invite/app.rs index 176075f..d4e877a 100644 --- a/src/invite/app.rs +++ b/src/invite/app.rs @@ -5,8 +5,11 @@ use super::{repo::Provider as _, Id, Invite, Summary}; use crate::{ clock::DateTime, db::{Duplicate as _, NotFound as _}, - event::{repo::Provider as _, Broadcaster, Event}, - login::{repo::Provider as _, Login, Password}, + event::Broadcaster, + login::{ + create::{self, Create}, + Login, Password, + }, name::Name, token::{repo::Provider as _, Secret}, }; @@ -44,6 +47,8 @@ impl<'a> Invites<'a> { password: &Password, accepted_at: &DateTime, ) -> Result<(Login, Secret), AcceptError> { + let create = Create::begin(name, password, accepted_at); + let mut tx = self.db.begin().await?; let invite = tx .invites() @@ -55,23 +60,20 @@ impl<'a> Invites<'a> { // the invite. Final validation is in the next tx. tx.commit().await?; - let password_hash = password.hash()?; + let validated = create.validate()?; let mut tx = self.db.begin().await?; // If the invite has been deleted or accepted in the interim, this step will // catch it. tx.invites().accept(&invite).await?; - let created = tx.sequence().next(accepted_at).await?; - let login = tx - .logins() - .create(name, &password_hash, &created) + let stored = validated + .store(&mut tx) .await .duplicate(|| AcceptError::DuplicateLogin(name.clone()))?; - let secret = tx.tokens().issue(&login, accepted_at).await?; + let secret = tx.tokens().issue(stored.login(), accepted_at).await?; tx.commit().await?; - self.events - .broadcast(login.events().map(Event::from).collect::<Vec<_>>()); + let login = stored.publish(self.events); Ok((login.as_created(), secret)) } @@ -92,6 +94,8 @@ impl<'a> Invites<'a> { pub enum AcceptError { #[error("invite not found: {0}")] NotFound(Id), + #[error("invalid login name: {0}")] + InvalidName(Name), #[error("name in use: {0}")] DuplicateLogin(Name), #[error(transparent)] @@ -99,3 +103,12 @@ pub enum AcceptError { #[error(transparent)] PasswordHash(#[from] password_hash::Error), } + +impl From<create::Error> for AcceptError { + fn from(error: create::Error) -> Self { + match error { + create::Error::InvalidName(name) => Self::InvalidName(name), + create::Error::PasswordHash(error) => Self::PasswordHash(error), + } + } +} diff --git a/src/invite/routes/invite/post.rs b/src/invite/routes/invite/post.rs index 0dd8dba..bb68e07 100644 --- a/src/invite/routes/invite/post.rs +++ b/src/invite/routes/invite/post.rs @@ -36,7 +36,8 @@ pub struct Request { pub password: Password, } -#[derive(Debug)] +#[derive(Debug, thiserror::Error)] +#[error(transparent)] pub struct Error(pub app::AcceptError); impl IntoResponse for Error { @@ -44,6 +45,9 @@ impl IntoResponse for Error { let Self(error) = self; match error { app::AcceptError::NotFound(_) => NotFound(error).into_response(), + app::AcceptError::InvalidName(_) => { + (StatusCode::BAD_REQUEST, error.to_string()).into_response() + } app::AcceptError::DuplicateLogin(_) => { (StatusCode::CONFLICT, error.to_string()).into_response() } diff --git a/src/invite/routes/invite/test/post.rs b/src/invite/routes/invite/test/post.rs index 65ab61e..40e0580 100644 --- a/src/invite/routes/invite/test/post.rs +++ b/src/invite/routes/invite/test/post.rs @@ -206,3 +206,35 @@ async fn conflicting_name() { matches!(error, AcceptError::DuplicateLogin(error_name) if error_name == conflicting_name) ); } + +#[tokio::test] +async fn invalid_name() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let issuer = fixtures::login::create(&app, &fixtures::now()).await; + let invite = fixtures::invite::issue(&app, &issuer, &fixtures::now()).await; + + // Call the endpoint + + let name = fixtures::login::propose_invalid_name(); + let password = fixtures::login::propose_password(); + let identity = fixtures::cookie::not_logged_in(); + let request = post::Request { + name: name.clone(), + password: password.clone(), + }; + let post::Error(error) = post::handler( + State(app.clone()), + fixtures::now(), + identity, + Path(invite.id), + Json(request), + ) + .await + .expect_err("using an invalid name should fail"); + + // Verify the response + + assert!(matches!(error, AcceptError::InvalidName(error_name) if name == error_name)); +} diff --git a/src/login/app.rs b/src/login/app.rs index 2f5896f..f458561 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,65 +1,56 @@ use sqlx::sqlite::SqlitePool; -use super::repo::Provider as _; - -#[cfg(test)] -use super::{Login, Password}; -#[cfg(test)] -use crate::{ - clock::DateTime, - event::{repo::Provider as _, Broadcaster, Event}, - name::Name, +use super::{ + create::{self, Create}, + Login, Password, }; +use crate::{clock::DateTime, event::Broadcaster, name::Name}; pub struct Logins<'a> { db: &'a SqlitePool, - #[cfg(test)] events: &'a Broadcaster, } impl<'a> Logins<'a> { - #[cfg(not(test))] - pub const fn new(db: &'a SqlitePool) -> Self { - Self { db } - } - - #[cfg(test)] pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { Self { db, events } } - #[cfg(test)] pub async fn create( &self, name: &Name, password: &Password, created_at: &DateTime, ) -> Result<Login, CreateError> { - let password_hash = password.hash()?; + let create = Create::begin(name, password, created_at); + let validated = create.validate()?; let mut tx = self.db.begin().await?; - let created = tx.sequence().next(created_at).await?; - let login = tx.logins().create(name, &password_hash, &created).await?; + let stored = validated.store(&mut tx).await?; tx.commit().await?; - self.events - .broadcast(login.events().map(Event::from).collect::<Vec<_>>()); + let login = stored.publish(self.events); Ok(login.as_created()) } - - pub async fn recanonicalize(&self) -> Result<(), sqlx::Error> { - let mut tx = self.db.begin().await?; - tx.logins().recanonicalize().await?; - tx.commit().await?; - - Ok(()) - } } #[derive(Debug, thiserror::Error)] -#[error(transparent)] pub enum CreateError { - Database(#[from] sqlx::Error), + #[error("invalid login name: {0}")] + InvalidName(Name), + #[error(transparent)] PasswordHash(#[from] password_hash::Error), + #[error(transparent)] + Database(#[from] sqlx::Error), +} + +#[cfg(test)] +impl From<create::Error> for CreateError { + fn from(error: create::Error) -> Self { + match error { + create::Error::InvalidName(name) => Self::InvalidName(name), + create::Error::PasswordHash(error) => Self::PasswordHash(error), + } + } } diff --git a/src/login/create.rs b/src/login/create.rs new file mode 100644 index 0000000..693daaf --- /dev/null +++ b/src/login/create.rs @@ -0,0 +1,95 @@ +use sqlx::{sqlite::Sqlite, Transaction}; + +use super::{password::StoredHash, repo::Provider as _, validate, History, Password}; +use crate::{ + clock::DateTime, + event::{repo::Provider as _, Broadcaster, Event}, + name::Name, +}; + +pub struct Create<'a> { + name: &'a Name, + password: &'a Password, + created_at: &'a DateTime, +} + +impl<'a> Create<'a> { + #[must_use = "dropping a login creation attempt is likely a mistake"] + pub fn begin(name: &'a Name, password: &'a Password, created_at: &'a DateTime) -> Self { + Self { + name, + password, + created_at, + } + } + + #[must_use = "dropping a login creation attempt is likely a mistake"] + pub fn validate(self) -> Result<Validated<'a>, Error> { + let Self { + name, + password, + created_at, + } = self; + + if !validate::name(name) { + return Err(Error::InvalidName(name.clone())); + } + + let password_hash = password.hash()?; + + Ok(Validated { + name, + password_hash, + created_at, + }) + } +} + +pub struct Validated<'a> { + name: &'a Name, + password_hash: StoredHash, + created_at: &'a DateTime, +} + +impl<'a> Validated<'a> { + #[must_use = "dropping a login creation attempt is likely a mistake"] + pub async fn store<'c>(self, tx: &mut Transaction<'c, Sqlite>) -> Result<Stored, sqlx::Error> { + let Self { + name, + password_hash, + created_at, + } = self; + + let created = tx.sequence().next(created_at).await?; + let login = tx.logins().create(name, &password_hash, &created).await?; + + Ok(Stored { login }) + } +} + +pub struct Stored { + login: History, +} + +impl Stored { + #[must_use = "dropping a login creation attempt is likely a mistake"] + pub fn publish(self, events: &Broadcaster) -> History { + let Self { login } = self; + + events.broadcast(login.events().map(Event::from).collect::<Vec<_>>()); + + login + } + + pub fn login(&self) -> &History { + &self.login + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("invalid login name: {0}")] + InvalidName(Name), + #[error(transparent)] + PasswordHash(#[from] password_hash::Error), +} diff --git a/src/login/history.rs b/src/login/history.rs index daad579..8161b0b 100644 --- a/src/login/history.rs +++ b/src/login/history.rs @@ -2,7 +2,7 @@ use super::{ event::{Created, Event}, Id, Login, }; -use crate::event::{Instant, ResumePoint, Sequence}; +use crate::event::{Instant, Sequence}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct History { @@ -24,9 +24,9 @@ impl History { self.login.clone() } - pub fn as_of(&self, resume_point: impl Into<ResumePoint>) -> Option<Login> { + pub fn as_of(&self, resume_point: Sequence) -> Option<Login> { self.events() - .filter(Sequence::up_to(resume_point.into())) + .filter(Sequence::up_to(resume_point)) .collect() } diff --git a/src/login/mod.rs b/src/login/mod.rs index 279e9a6..006fa0c 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -1,4 +1,6 @@ +#[cfg(test)] pub mod app; +pub mod create; pub mod event; mod history; mod id; @@ -6,6 +8,7 @@ pub mod password; pub mod repo; mod routes; mod snapshot; +mod validate; pub use self::{ event::Event, history::History, id::Id, password::Password, routes::router, snapshot::Login, diff --git a/src/login/repo.rs b/src/login/repo.rs index 611edd6..1c63a4b 100644 --- a/src/login/repo.rs +++ b/src/login/repo.rs @@ -3,7 +3,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, - event::{Instant, ResumePoint, Sequence}, + event::{Instant, Sequence}, login::{password::StoredHash, History, Id, Login}, name::{self, Name}, }; @@ -58,7 +58,30 @@ impl<'c> Logins<'c> { Ok(login) } - pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> { + pub async fn set_password( + &mut self, + login: &History, + to: &StoredHash, + ) -> Result<(), sqlx::Error> { + let login = login.id(); + + sqlx::query_scalar!( + r#" + update login + set password_hash = $1 + where id = $2 + returning id as "id: Id" + "#, + to, + login, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(()) + } + + pub async fn all(&mut self, resume_at: Sequence) -> Result<Vec<History>, LoadError> { let logins = sqlx::query!( r#" select @@ -68,7 +91,7 @@ impl<'c> Logins<'c> { created_sequence as "created_sequence: Sequence", created_at as "created_at: DateTime" from login - where coalesce(created_sequence <= $1, true) + where created_sequence <= $1 order by canonical_name "#, resume_at, @@ -90,7 +113,7 @@ impl<'c> Logins<'c> { Ok(logins) } - pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, LoadError> { + pub async fn replay(&mut self, resume_at: Sequence) -> Result<Vec<History>, LoadError> { let logins = sqlx::query!( r#" select @@ -100,7 +123,7 @@ impl<'c> Logins<'c> { created_sequence as "created_sequence: Sequence", created_at as "created_at: DateTime" from login - where coalesce(login.created_sequence > $1, true) + where login.created_sequence > $1 "#, resume_at, ) @@ -120,38 +143,6 @@ impl<'c> Logins<'c> { Ok(logins) } - - pub async fn recanonicalize(&mut self) -> Result<(), sqlx::Error> { - let logins = sqlx::query!( - r#" - select - id as "id: Id", - display_name as "display_name: String" - from login - "#, - ) - .fetch_all(&mut *self.0) - .await?; - - for login in logins { - let name = Name::from(login.display_name); - let canonical_name = name.canonical(); - - sqlx::query!( - r#" - update login - set canonical_name = $1 - where id = $2 - "#, - canonical_name, - login.id, - ) - .execute(&mut *self.0) - .await?; - } - - Ok(()) - } } #[derive(Debug, thiserror::Error)] diff --git a/src/login/routes/mod.rs b/src/login/routes/mod.rs index 8cb8852..bbd0c3f 100644 --- a/src/login/routes/mod.rs +++ b/src/login/routes/mod.rs @@ -4,9 +4,11 @@ use crate::app::App; mod login; mod logout; +mod password; pub fn router() -> Router<App> { Router::new() + .route("/api/password", post(password::post::handler)) .route("/api/auth/login", post(login::post::handler)) .route("/api/auth/logout", post(logout::post::handler)) } diff --git a/src/login/routes/password/mod.rs b/src/login/routes/password/mod.rs new file mode 100644 index 0000000..36b384e --- /dev/null +++ b/src/login/routes/password/mod.rs @@ -0,0 +1,4 @@ +pub mod post; + +#[cfg(test)] +mod test; diff --git a/src/login/routes/password/post.rs b/src/login/routes/password/post.rs new file mode 100644 index 0000000..4723754 --- /dev/null +++ b/src/login/routes/password/post.rs @@ -0,0 +1,54 @@ +use axum::{ + extract::{Json, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; + +use crate::{ + app::App, + clock::RequestedAt, + error::Internal, + login::{Login, Password}, + token::{ + app, + extract::{Identity, IdentityCookie}, + }, +}; + +pub async fn handler( + State(app): State<App>, + RequestedAt(now): RequestedAt, + identity: Identity, + cookie: IdentityCookie, + Json(request): Json<Request>, +) -> Result<(IdentityCookie, Json<Login>), Error> { + let (login, secret) = app + .tokens() + .change_password(&identity.login, &request.password, &request.to, &now) + .await + .map_err(Error)?; + let cookie = cookie.set(secret); + Ok((cookie, Json(login))) +} + +#[derive(serde::Deserialize)] +pub struct Request { + pub password: Password, + pub to: Password, +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct Error(#[from] pub app::LoginError); + +impl IntoResponse for Error { + fn into_response(self) -> Response { + let Self(error) = self; + match error { + app::LoginError::Rejected => { + (StatusCode::BAD_REQUEST, "invalid name or password").into_response() + } + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/login/routes/password/test.rs b/src/login/routes/password/test.rs new file mode 100644 index 0000000..c1974bf --- /dev/null +++ b/src/login/routes/password/test.rs @@ -0,0 +1,68 @@ +use axum::extract::{Json, State}; + +use super::post; +use crate::{ + test::fixtures, + token::app::{LoginError, ValidateError}, +}; + +#[tokio::test] +async fn password_change() { + // Set up the environment + let app = fixtures::scratch_app().await; + let creds = fixtures::login::create_with_password(&app, &fixtures::now()).await; + let cookie = fixtures::cookie::logged_in(&app, &creds, &fixtures::now()).await; + let identity = fixtures::identity::from_cookie(&app, &cookie, &fixtures::now()).await; + + // Call the endpoint + let (name, password) = creds; + let to = fixtures::login::propose_password(); + let request = post::Request { + password: password.clone(), + to: to.clone(), + }; + let (new_cookie, Json(response)) = post::handler( + State(app.clone()), + fixtures::now(), + identity.clone(), + cookie.clone(), + Json(request), + ) + .await + .expect("changing passwords succeeds"); + + // Verify that we have a new session + assert_ne!(cookie.secret(), new_cookie.secret()); + + // Verify that we're still ourselves + assert_eq!(identity.login, response); + + // Verify that our original token is no longer valid + let validate_err = app + .tokens() + .validate( + &cookie + .secret() + .expect("original identity cookie has a secret"), + &fixtures::now(), + ) + .await + .expect_err("validating the original identity secret should fail"); + assert!(matches!(validate_err, ValidateError::InvalidToken)); + + // Verify that our original password is no longer valid + let login_err = app + .tokens() + .login(&name, &password, &fixtures::now()) + .await + .expect_err("logging in with the original password should fail"); + assert!(matches!(login_err, LoginError::Rejected)); + + // Verify that our new password is valid + let (login, _) = app + .tokens() + .login(&name, &to, &fixtures::now()) + .await + .expect("logging in with the new password should succeed"); + assert_eq!(identity.login, login); +} diff --git a/src/login/validate.rs b/src/login/validate.rs new file mode 100644 index 0000000..0c97293 --- /dev/null +++ b/src/login/validate.rs @@ -0,0 +1,23 @@ +use unicode_segmentation::UnicodeSegmentation as _; + +use crate::name::Name; + +// Picked out of a hat. The power of two is not meaningful. +const NAME_TOO_LONG: usize = 64; + +pub fn name(name: &Name) -> bool { + let display = name.display(); + + [ + display.graphemes(true).count() < NAME_TOO_LONG, + display.chars().all(|ch| !ch.is_control()), + display.chars().next().is_some_and(|c| !c.is_whitespace()), + display.chars().last().is_some_and(|c| !c.is_whitespace()), + display + .chars() + .zip(display.chars().skip(1)) + .all(|(a, b)| !(a.is_whitespace() && b.is_whitespace())), + ] + .into_iter() + .all(|value| value) +} diff --git a/src/bin/hi.rs b/src/main.rs index d0830ff..d0830ff 100644 --- a/src/bin/hi.rs +++ b/src/main.rs diff --git a/src/message/app.rs b/src/message/app.rs index eed6ba4..137a27d 100644 --- a/src/message/app.rs +++ b/src/message/app.rs @@ -45,16 +45,24 @@ impl<'a> Messages<'a> { Ok(message.as_sent()) } - pub async fn delete(&self, message: &Id, deleted_at: &DateTime) -> Result<(), DeleteError> { + pub async fn delete( + &self, + deleted_by: &Login, + message: &Id, + deleted_at: &DateTime, + ) -> Result<(), DeleteError> { let mut tx = self.db.begin().await?; let message = tx .messages() .by_id(message) .await .not_found(|| DeleteError::NotFound(message.clone()))?; - message + let snapshot = message .as_snapshot() .ok_or_else(|| DeleteError::Deleted(message.id().clone()))?; + if snapshot.sender != deleted_by.id { + return Err(DeleteError::NotSender(deleted_by.clone())); + } let deleted = tx.sequence().next(deleted_at).await?; let message = tx.messages().delete(&message, &deleted).await?; @@ -138,6 +146,8 @@ impl From<channel::repo::LoadError> for SendError { pub enum DeleteError { #[error("message {0} not found")] NotFound(Id), + #[error("login {} not the message's sender", .0.id)] + NotSender(Login), #[error("message {0} deleted")] Deleted(Id), #[error(transparent)] diff --git a/src/message/history.rs b/src/message/history.rs index 0424d0d..ed8f5df 100644 --- a/src/message/history.rs +++ b/src/message/history.rs @@ -1,8 +1,10 @@ +use itertools::Itertools as _; + use super::{ event::{Deleted, Event, Sent}, Id, Message, }; -use crate::event::{Instant, ResumePoint, Sequence}; +use crate::event::{Instant, Sequence}; #[derive(Clone, Debug, Eq, PartialEq)] pub struct History { @@ -25,9 +27,9 @@ impl History { self.message.clone() } - pub fn as_of(&self, resume_point: impl Into<ResumePoint>) -> Option<Message> { + pub fn as_of(&self, resume_point: Sequence) -> Option<Message> { self.events() - .filter(Sequence::up_to(resume_point.into())) + .filter(Sequence::up_to(resume_point)) .collect() } @@ -57,6 +59,8 @@ impl History { } pub fn events(&self) -> impl Iterator<Item = Event> { - [self.sent()].into_iter().chain(self.deleted()) + [self.sent()] + .into_iter() + .merge_by(self.deleted(), Sequence::merge) } } diff --git a/src/message/repo.rs b/src/message/repo.rs index c8ceceb..14f8eaf 100644 --- a/src/message/repo.rs +++ b/src/message/repo.rs @@ -4,7 +4,7 @@ use super::{snapshot::Message, Body, History, Id}; use crate::{ channel, clock::DateTime, - event::{Instant, ResumePoint, Sequence}, + event::{Instant, Sequence}, login::{self, Login}, }; @@ -34,8 +34,8 @@ impl<'c> Messages<'c> { let message = sqlx::query!( r#" insert into message - (id, channel, sender, sent_at, sent_sequence, body) - values ($1, $2, $3, $4, $5, $6) + (id, channel, sender, sent_at, sent_sequence, body, last_sequence) + values ($1, $2, $3, $4, $5, $6, $7) returning id as "id: Id", channel as "channel: channel::Id", @@ -50,6 +50,7 @@ impl<'c> Messages<'c> { sent.at, sent.sequence, body, + sent.sequence, ) .map(|row| History { message: Message { @@ -106,22 +107,22 @@ impl<'c> Messages<'c> { Ok(messages) } - pub async fn all(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> { + pub async fn all(&mut self, resume_at: Sequence) -> Result<Vec<History>, sqlx::Error> { let messages = sqlx::query!( r#" select message.channel as "channel: channel::Id", message.sender as "sender: login::Id", - id as "id: Id", + message.id as "id: Id", message.body as "body: Body", message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", - deleted.deleted_at as "deleted_at: DateTime", - deleted.deleted_sequence as "deleted_sequence: Sequence" + deleted.deleted_at as "deleted_at?: DateTime", + deleted.deleted_sequence as "deleted_sequence?: Sequence" from message left join message_deleted as deleted using (id) - where coalesce(message.sent_sequence <= $2, true) + where message.sent_sequence <= $1 order by message.sent_sequence "#, resume_at, @@ -205,12 +206,14 @@ impl<'c> Messages<'c> { sqlx::query!( r#" update message - set body = '' - where id = $1 + set body = '', last_sequence = max(last_sequence, $1) + where id = $2 + returning id as "id: Id" "#, + deleted.sequence, id, ) - .execute(&mut *self.0) + .fetch_one(&mut *self.0) .await?; let message = self.by_id(id).await?; @@ -282,7 +285,7 @@ impl<'c> Messages<'c> { Ok(messages) } - pub async fn replay(&mut self, resume_at: ResumePoint) -> Result<Vec<History>, sqlx::Error> { + pub async fn replay(&mut self, resume_at: Sequence) -> Result<Vec<History>, sqlx::Error> { let messages = sqlx::query!( r#" select @@ -292,12 +295,12 @@ impl<'c> Messages<'c> { message.sent_at as "sent_at: DateTime", message.sent_sequence as "sent_sequence: Sequence", message.body as "body: Body", - deleted.deleted_at as "deleted_at: DateTime", - deleted.deleted_sequence as "deleted_sequence: Sequence" + deleted.deleted_at as "deleted_at?: DateTime", + deleted.deleted_sequence as "deleted_sequence?: Sequence" from message left join message_deleted as deleted using (id) - where coalesce(message.sent_sequence > $1, true) + where message.last_sequence > $1 "#, resume_at, ) diff --git a/src/message/routes/message/mod.rs b/src/message/routes/message/mod.rs index 45a7e9d..e92f556 100644 --- a/src/message/routes/message/mod.rs +++ b/src/message/routes/message/mod.rs @@ -20,9 +20,11 @@ pub mod delete { State(app): State<App>, Path(message): Path<message::Id>, RequestedAt(deleted_at): RequestedAt, - _: Identity, + identity: Identity, ) -> Result<Response, Error> { - app.messages().delete(&message, &deleted_at).await?; + app.messages() + .delete(&identity.login, &message, &deleted_at) + .await?; Ok(Response { id: message }) } @@ -47,6 +49,9 @@ pub mod delete { let Self(error) = self; #[allow(clippy::match_wildcard_for_single_variants)] match error { + DeleteError::NotSender(_) => { + (StatusCode::FORBIDDEN, error.to_string()).into_response() + } DeleteError::NotFound(_) | DeleteError::Deleted(_) => { NotFound(error).into_response() } diff --git a/src/message/routes/message/test.rs b/src/message/routes/message/test.rs index ae89506..5178ab5 100644 --- a/src/message/routes/message/test.rs +++ b/src/message/routes/message/test.rs @@ -8,18 +8,17 @@ pub async fn delete_message() { // Set up the environment let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app, &fixtures::now()).await; + let sender = fixtures::identity::create(&app, &fixtures::now()).await; let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; + let message = fixtures::message::send(&app, &channel, &sender.login, &fixtures::now()).await; // Send the request - let deleter = fixtures::identity::create(&app, &fixtures::now()).await; let response = delete::handler( State(app.clone()), Path(message.id.clone()), fixtures::now(), - deleter, + sender, ) .await .expect("deleting a valid message succeeds"); @@ -68,7 +67,7 @@ pub async fn delete_deleted() { let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; app.messages() - .delete(&message.id, &fixtures::now()) + .delete(&sender, &message.id, &fixtures::now()) .await .expect("deleting a recently-sent message succeeds"); @@ -155,3 +154,31 @@ pub async fn delete_purged() { assert!(matches!(error, app::DeleteError::NotFound(id) if id == message.id)); } + +#[tokio::test] +pub async fn delete_not_sender() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app, &fixtures::now()).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await; + + // Send the request + + let deleter = fixtures::identity::create(&app, &fixtures::now()).await; + let delete::Error(error) = delete::handler( + State(app.clone()), + Path(message.id.clone()), + fixtures::now(), + deleter.clone(), + ) + .await + .expect_err("deleting a message someone else sent fails"); + + // Verify the response + + assert!( + matches!(error, app::DeleteError::NotSender(error_sender) if deleter.login == error_sender) + ); +} diff --git a/src/setup/app.rs b/src/setup/app.rs index 030b5f6..c1f7b69 100644 --- a/src/setup/app.rs +++ b/src/setup/app.rs @@ -3,8 +3,11 @@ use sqlx::sqlite::SqlitePool; use super::repo::Provider as _; use crate::{ clock::DateTime, - event::{repo::Provider as _, Broadcaster, Event}, - login::{repo::Provider as _, Login, Password}, + event::Broadcaster, + login::{ + create::{self, Create}, + Login, Password, + }, name::Name, token::{repo::Provider as _, Secret}, }; @@ -25,20 +28,20 @@ impl<'a> Setup<'a> { password: &Password, created_at: &DateTime, ) -> Result<(Login, Secret), Error> { - let password_hash = password.hash()?; + let create = Create::begin(name, password, created_at); + + let validated = create.validate()?; let mut tx = self.db.begin().await?; - let login = if tx.setup().completed().await? { + let stored = if tx.setup().completed().await? { Err(Error::SetupCompleted)? } else { - let created = tx.sequence().next(created_at).await?; - tx.logins().create(name, &password_hash, &created).await? + validated.store(&mut tx).await? }; - let secret = tx.tokens().issue(&login, created_at).await?; + let secret = tx.tokens().issue(stored.login(), created_at).await?; tx.commit().await?; - self.events - .broadcast(login.events().map(Event::from).collect::<Vec<_>>()); + let login = stored.publish(self.events); Ok((login.as_created(), secret)) } @@ -56,8 +59,19 @@ impl<'a> Setup<'a> { pub enum Error { #[error("initial setup previously completed")] SetupCompleted, + #[error("invalid login name: {0}")] + InvalidName(Name), #[error(transparent)] Database(#[from] sqlx::Error), #[error(transparent)] PasswordHash(#[from] password_hash::Error), } + +impl From<create::Error> for Error { + fn from(error: create::Error) -> Self { + match error { + create::Error::InvalidName(name) => Self::InvalidName(name), + create::Error::PasswordHash(error) => Self::PasswordHash(error), + } + } +} diff --git a/src/setup/routes/post.rs b/src/setup/routes/post.rs index f7b256e..2a46b04 100644 --- a/src/setup/routes/post.rs +++ b/src/setup/routes/post.rs @@ -42,6 +42,9 @@ impl IntoResponse for Error { fn into_response(self) -> Response { let Self(error) = self; match error { + app::Error::InvalidName(_) => { + (StatusCode::BAD_REQUEST, error.to_string()).into_response() + } app::Error::SetupCompleted => (StatusCode::CONFLICT, error.to_string()).into_response(), other => Internal::from(other).into_response(), } diff --git a/src/setup/routes/test.rs b/src/setup/routes/test.rs index f7562ae..5794b78 100644 --- a/src/setup/routes/test.rs +++ b/src/setup/routes/test.rs @@ -67,3 +67,28 @@ async fn login_exists() { assert!(matches!(error, app::Error::SetupCompleted)); } + +#[tokio::test] +async fn invalid_name() { + // Set up the environment + + let app = fixtures::scratch_app().await; + + // Call the endpoint + + let name = fixtures::login::propose_invalid_name(); + let password = fixtures::login::propose_password(); + let identity = fixtures::cookie::not_logged_in(); + let request = post::Request { + name: name.clone(), + password: password.clone(), + }; + let post::Error(error) = + post::handler(State(app.clone()), fixtures::now(), identity, Json(request)) + .await + .expect_err("setup with an invalid name fails"); + + // Verify the response + + assert!(matches!(error, app::Error::InvalidName(error_name) if name == error_name)); +} diff --git a/src/test/fixtures/boot.rs b/src/test/fixtures/boot.rs new file mode 100644 index 0000000..120726f --- /dev/null +++ b/src/test/fixtures/boot.rs @@ -0,0 +1,9 @@ +use crate::{app::App, event::Sequence}; + +pub async fn resume_point(app: &App) -> Sequence { + app.boot() + .snapshot() + .await + .expect("boot always succeeds") + .resume_point +} diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs index 0c6480b..98048f2 100644 --- a/src/test/fixtures/channel.rs +++ b/src/test/fixtures/channel.rs @@ -1,6 +1,7 @@ use faker_rand::{ en_us::{addresses::CityName, names::FullName}, faker_impl_from_templates, + lorem::Paragraphs, }; use rand; @@ -23,6 +24,10 @@ pub fn propose() -> Name { rand::random::<NameTemplate>().to_string().into() } +pub fn propose_invalid_name() -> Name { + rand::random::<Paragraphs>().to_string().into() +} + struct NameTemplate(String); faker_impl_from_templates! { NameTemplate; "{} {}", CityName, FullName; diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index e438f2b..ffc44c6 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -15,11 +15,15 @@ pub async fn create(app: &App, created_at: &RequestedAt) -> Identity { logged_in(app, &credentials, created_at).await } -pub async fn from_cookie(app: &App, token: &IdentityCookie, issued_at: &RequestedAt) -> Identity { - let secret = token.secret().expect("identity token has a secret"); +pub async fn from_cookie( + app: &App, + cookie: &IdentityCookie, + validated_at: &RequestedAt, +) -> Identity { + let secret = cookie.secret().expect("identity token has a secret"); let (token, login) = app .tokens() - .validate(&secret, issued_at) + .validate(&secret, validated_at) .await .expect("always validates newly-issued secret"); diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index e308289..86e3e39 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -1,4 +1,4 @@ -use faker_rand::en_us::internet; +use faker_rand::{en_us::internet, lorem::Paragraphs}; use uuid::Uuid; use crate::{ @@ -38,6 +38,10 @@ pub fn propose() -> (Name, Password) { (propose_name(), propose_password()) } +pub fn propose_invalid_name() -> Name { + rand::random::<Paragraphs>().to_string().into() +} + fn propose_name() -> Name { rand::random::<internet::Username>().to_string().into() } diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index 2b7b6af..470b31a 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -2,6 +2,7 @@ use chrono::{TimeDelta, Utc}; use crate::{app::App, clock::RequestedAt, db}; +pub mod boot; pub mod channel; pub mod cookie; pub mod event; diff --git a/src/token/app.rs b/src/token/app.rs index c19d6a0..5c0aeb0 100644 --- a/src/token/app.rs +++ b/src/token/app.rs @@ -13,7 +13,7 @@ use super::{ use crate::{ clock::DateTime, db::NotFound as _, - login::{Login, Password}, + login::{repo::Provider as _, Login, Password}, name::{self, Name}, }; @@ -61,6 +61,47 @@ impl<'a> Tokens<'a> { Ok((snapshot, token)) } + pub async fn change_password( + &self, + login: &Login, + password: &Password, + to: &Password, + changed_at: &DateTime, + ) -> Result<(Login, Secret), LoginError> { + let mut tx = self.db.begin().await?; + let (login, stored_hash) = tx + .auth() + .for_login(login) + .await + .optional()? + .ok_or(LoginError::Rejected)?; + // Split the transaction here to avoid holding the tx open (potentially blocking + // other writes) while we do the fairly expensive task of verifying the + // password. It's okay if the token issuance transaction happens some notional + // amount of time after retrieving the login, as inserting the token will fail + // if the account is deleted during that time. + tx.commit().await?; + + if !stored_hash.verify(password)? { + return Err(LoginError::Rejected); + } + + let snapshot = login.as_snapshot().ok_or(LoginError::Rejected)?; + let to_hash = to.hash()?; + + let mut tx = self.db.begin().await?; + let tokens = tx.tokens().revoke_all(&login).await?; + tx.logins().set_password(&login, &to_hash).await?; + let secret = tx.tokens().issue(&login, changed_at).await?; + tx.commit().await?; + + for event in tokens.into_iter().map(TokenEvent::Revoked) { + self.token_events.broadcast(event); + } + + Ok((snapshot, secret)) + } + pub async fn validate( &self, secret: &Secret, diff --git a/src/token/repo/auth.rs b/src/token/repo/auth.rs index bdc4c33..b51db8c 100644 --- a/src/token/repo/auth.rs +++ b/src/token/repo/auth.rs @@ -50,6 +50,35 @@ impl<'t> Auth<'t> { Ok((login, row.password_hash)) } + + pub async fn for_login(&mut self, login: &Login) -> Result<(History, StoredHash), LoadError> { + let row = sqlx::query!( + r#" + select + id as "id: login::Id", + display_name as "display_name: String", + canonical_name as "canonical_name: String", + created_sequence as "created_sequence: Sequence", + created_at as "created_at: DateTime", + password_hash as "password_hash: StoredHash" + from login + where id = $1 + "#, + login.id, + ) + .fetch_one(&mut *self.0) + .await?; + + let login = History { + login: Login { + id: row.id, + name: Name::new(row.display_name, row.canonical_name)?, + }, + created: Instant::new(row.created_at, row.created_sequence), + }; + + Ok((login, row.password_hash)) + } } #[derive(Debug, thiserror::Error)] diff --git a/src/token/repo/token.rs b/src/token/repo/token.rs index 35ea385..33b89d5 100644 --- a/src/token/repo/token.rs +++ b/src/token/repo/token.rs @@ -84,6 +84,24 @@ impl<'c> Tokens<'c> { Ok(()) } + // Revoke tokens for a login + pub async fn revoke_all(&mut self, login: &login::History) -> Result<Vec<Id>, sqlx::Error> { + let login = login.id(); + let tokens = sqlx::query_scalar!( + r#" + delete + from token + where login = $1 + returning id as "id: Id" + "#, + login, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(tokens) + } + // Expire and delete all tokens that haven't been used more recently than // `expire_at`. pub async fn expire(&mut self, expire_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> { diff --git a/src/token/secret.rs b/src/token/secret.rs index 28c93bb..8f70646 100644 --- a/src/token/secret.rs +++ b/src/token/secret.rs @@ -1,6 +1,6 @@ use std::fmt; -#[derive(sqlx::Type)] +#[derive(PartialEq, Eq, sqlx::Type)] #[sqlx(transparent)] pub struct Secret(String); diff --git a/src/ui/routes/me.rs b/src/ui/routes/me.rs new file mode 100644 index 0000000..f1f118f --- /dev/null +++ b/src/ui/routes/me.rs @@ -0,0 +1,32 @@ +pub mod get { + use axum::response::{self, IntoResponse, Redirect}; + + use crate::{ + error::Internal, + token::extract::Identity, + ui::assets::{Asset, Assets}, + }; + + pub async fn handler(identity: Option<Identity>) -> Result<Asset, Error> { + let _ = identity.ok_or(Error::NotLoggedIn)?; + + Assets::index().map_err(Error::Internal) + } + + #[derive(Debug, thiserror::Error)] + pub enum Error { + #[error("not logged in")] + NotLoggedIn, + #[error("{0}")] + Internal(Internal), + } + + impl IntoResponse for Error { + fn into_response(self) -> response::Response { + match self { + Self::NotLoggedIn => Redirect::temporary("/login").into_response(), + Self::Internal(error) => error.into_response(), + } + } + } +} diff --git a/src/ui/routes/mod.rs b/src/ui/routes/mod.rs index 72d9a4a..48b3f90 100644 --- a/src/ui/routes/mod.rs +++ b/src/ui/routes/mod.rs @@ -6,6 +6,7 @@ mod ch; mod get; mod invite; mod login; +mod me; mod path; mod setup; @@ -16,6 +17,7 @@ pub fn router(app: &App) -> Router<App> { .route("/setup", get(setup::get::handler)), Router::new() .route("/", get(get::handler)) + .route("/me", get(me::get::handler)) .route("/login", get(login::get::handler)) .route("/ch/:channel", get(ch::channel::get::handler)) .route("/invite/:invite", get(invite::invite::get::handler)) |
