From f878f0b5eaa44e8ee8d67cbfd706926ff2119113 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Tue, 1 Oct 2024 23:57:22 -0400 Subject: Organize IDs into top-level namespaces. (This is part of a larger reorganization.) --- src/channel/mod.rs | 3 +++ 1 file changed, 3 insertions(+) (limited to 'src/channel/mod.rs') diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 9f79dbb..3115e98 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,4 +1,7 @@ pub mod app; +mod id; mod routes; pub use self::routes::router; + +pub use self::id::Id; -- cgit v1.2.3 From 357116366c1307bedaac6a3dfe9c5ed8e0e0c210 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Wed, 2 Oct 2024 00:41:25 -0400 Subject: First pass on reorganizing the backend. This is primarily renames and repackagings. --- ...4db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json | 20 + ...884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json | 20 - src/app.rs | 2 +- src/channel/app.rs | 9 +- src/channel/mod.rs | 14 +- src/channel/routes.rs | 6 +- src/channel/routes/test/on_create.rs | 2 +- src/channel/routes/test/on_send.rs | 2 +- src/cli.rs | 4 +- src/event/app.rs | 137 +++++++ src/event/broadcaster.rs | 3 + src/event/extract.rs | 85 ++++ src/event/mod.rs | 9 + src/event/repo/message.rs | 188 +++++++++ src/event/repo/mod.rs | 1 + src/event/routes.rs | 93 +++++ src/event/routes/test.rs | 439 +++++++++++++++++++++ src/event/sequence.rs | 24 ++ src/event/types.rs | 97 +++++ src/events/app.rs | 140 ------- src/events/broadcaster.rs | 3 - src/events/extract.rs | 85 ---- src/events/mod.rs | 8 - src/events/repo/message.rs | 188 --------- src/events/repo/mod.rs | 1 - src/events/routes.rs | 92 ----- src/events/routes/test.rs | 439 --------------------- src/events/types.rs | 96 ----- src/lib.rs | 4 +- src/login/app.rs | 17 +- src/login/extract.rs | 181 +-------- src/login/mod.rs | 18 +- src/login/password.rs | 58 +++ src/login/repo/auth.rs | 2 +- src/login/routes.rs | 6 +- src/login/token/id.rs | 27 -- src/login/token/mod.rs | 3 - src/login/types.rs | 2 +- src/message/mod.rs | 6 + src/password.rs | 58 --- src/repo/channel.rs | 15 +- src/repo/login.rs | 50 +++ src/repo/login/extract.rs | 15 - src/repo/login/mod.rs | 4 - src/repo/login/store.rs | 63 --- src/repo/message.rs | 7 - src/repo/mod.rs | 1 - src/repo/sequence.rs | 27 +- src/repo/token.rs | 10 +- src/test/fixtures/channel.rs | 2 +- src/test/fixtures/filter.rs | 2 +- src/test/fixtures/identity.rs | 9 +- src/test/fixtures/login.rs | 3 +- src/test/fixtures/message.rs | 7 +- src/token/extract/identity.rs | 75 ++++ src/token/extract/identity_token.rs | 94 +++++ src/token/extract/mod.rs | 4 + src/token/id.rs | 27 ++ src/token/mod.rs | 5 + src/token/secret.rs | 27 ++ 60 files changed, 1521 insertions(+), 1515 deletions(-) create mode 100644 .sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json delete mode 100644 .sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json create mode 100644 src/event/app.rs create mode 100644 src/event/broadcaster.rs create mode 100644 src/event/extract.rs create mode 100644 src/event/mod.rs create mode 100644 src/event/repo/message.rs create mode 100644 src/event/repo/mod.rs create mode 100644 src/event/routes.rs create mode 100644 src/event/routes/test.rs create mode 100644 src/event/sequence.rs create mode 100644 src/event/types.rs delete mode 100644 src/events/app.rs delete mode 100644 src/events/broadcaster.rs delete mode 100644 src/events/extract.rs delete mode 100644 src/events/mod.rs delete mode 100644 
src/events/repo/message.rs delete mode 100644 src/events/repo/mod.rs delete mode 100644 src/events/routes.rs delete mode 100644 src/events/routes/test.rs delete mode 100644 src/events/types.rs create mode 100644 src/login/password.rs delete mode 100644 src/login/token/id.rs delete mode 100644 src/login/token/mod.rs delete mode 100644 src/password.rs create mode 100644 src/repo/login.rs delete mode 100644 src/repo/login/extract.rs delete mode 100644 src/repo/login/mod.rs delete mode 100644 src/repo/login/store.rs delete mode 100644 src/repo/message.rs create mode 100644 src/token/extract/identity.rs create mode 100644 src/token/extract/identity_token.rs create mode 100644 src/token/extract/mod.rs create mode 100644 src/token/id.rs create mode 100644 src/token/mod.rs create mode 100644 src/token/secret.rs (limited to 'src/channel/mod.rs') diff --git a/.sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json b/.sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json new file mode 100644 index 0000000..b433e4c --- /dev/null +++ b/.sqlx/query-8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n insert\n into token (id, secret, login, issued_at, last_used_at)\n values ($1, $2, $3, $4, $4)\n returning secret as \"secret!: Secret\"\n ", + "describe": { + "columns": [ + { + "name": "secret!: Secret", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 4 + }, + "nullable": [ + false + ] + }, + "hash": "8b474c8ed7859f745888644db639b7a4a21210ebf0b7fba97cd016ff6ab4d769" +} diff --git a/.sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json b/.sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json deleted file mode 100644 index eda6697..0000000 --- a/.sqlx/query-e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n insert\n into token (id, secret, login, issued_at, last_used_at)\n values ($1, $2, $3, $4, $4)\n returning secret as \"secret!: IdentitySecret\"\n ", - "describe": { - "columns": [ - { - "name": "secret!: IdentitySecret", - "ordinal": 0, - "type_info": "Text" - } - ], - "parameters": { - "Right": 4 - }, - "nullable": [ - false - ] - }, - "hash": "e0deb4dfaffe4527ad630c884970c3eb93fc734d979e6b8e78cd7d0b6dd0b669" -} diff --git a/src/app.rs b/src/app.rs index c13f52f..84a6357 100644 --- a/src/app.rs +++ b/src/app.rs @@ -2,7 +2,7 @@ use sqlx::sqlite::SqlitePool; use crate::{ channel::app::Channels, - events::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, + event::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, login::{app::Logins, broadcaster::Broadcaster as LoginBroadcaster}, }; diff --git a/src/channel/app.rs b/src/channel/app.rs index d89e733..1422651 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,12 +2,11 @@ use chrono::TimeDelta; use sqlx::sqlite::SqlitePool; use crate::{ + channel::Channel, clock::DateTime, - events::{broadcaster::Broadcaster, types::ChannelEvent}, - repo::{ - channel::{Channel, Provider as _}, - sequence::{Provider as _, Sequence}, - }, + event::Sequence, + event::{broadcaster::Broadcaster, types::ChannelEvent}, + repo::{channel::Provider as _, sequence::Provider as _}, }; pub struct Channels<'a> { diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 3115e98..02d0ed4 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,7 +1,17 @@ 
+use crate::{clock::DateTime, event::Sequence}; + pub mod app; mod id; mod routes; -pub use self::routes::router; +pub use self::{id::Id, routes::router}; -pub use self::id::Id; +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Channel { + pub id: Id, + pub name: String, + #[serde(skip)] + pub created_at: DateTime, + #[serde(skip)] + pub created_sequence: Sequence, +} diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 72d6195..5d8b61e 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -10,11 +10,11 @@ use axum_extra::extract::Query; use super::app; use crate::{ app::App, - channel, + channel::{self, Channel}, clock::RequestedAt, error::Internal, - events::app::EventsError, - repo::{channel::Channel, login::Login, sequence::Sequence}, + event::{app::EventsError, Sequence}, + login::Login, }; #[cfg(test)] diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 72980ac..9988932 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -3,7 +3,7 @@ use futures::stream::StreamExt as _; use crate::{ channel::{app, routes}, - events::types, + event::types, test::fixtures::{self, future::Immediately as _}, }; diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 987784d..6f844cd 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -4,7 +4,7 @@ use futures::stream::StreamExt; use crate::{ channel, channel::routes, - events::{app, types}, + event::{app, types}, test::fixtures::{self, future::Immediately as _}, }; diff --git a/src/cli.rs b/src/cli.rs index 132baf8..ee95ea6 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, events, expire, login, repo::pool}; +use crate::{app::App, channel, clock, event, expire, login, repo::pool}; /// Command-line entry point for running the `hi` server. 
/// @@ -105,7 +105,7 @@ impl Args { } fn routers() -> Router<App> { - [channel::router(), events::router(), login::router()] + [channel::router(), event::router(), login::router()] .into_iter() .fold(Router::default(), Router::merge) } diff --git a/src/event/app.rs b/src/event/app.rs new file mode 100644 index 0000000..b5f2ecc --- /dev/null +++ b/src/event/app.rs @@ -0,0 +1,137 @@ +use chrono::TimeDelta; +use futures::{ + future, + stream::{self, StreamExt as _}, + Stream, +}; +use sqlx::sqlite::SqlitePool; + +use super::{ + broadcaster::Broadcaster, + repo::message::Provider as _, + types::{self, ChannelEvent}, +}; +use crate::{ + channel, + clock::DateTime, + event::Sequence, + login::Login, + repo::{channel::Provider as _, error::NotFound as _, sequence::Provider as _}, +}; + +pub struct Events<'a> { + db: &'a SqlitePool, + events: &'a Broadcaster, +} + +impl<'a> Events<'a> { + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } + } + + pub async fn send( + &self, + login: &Login, + channel: &channel::Id, + body: &str, + sent_at: &DateTime, + ) -> Result<ChannelEvent, EventsError> { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let sent_sequence = tx.sequence().next().await?; + let event = tx + .message_events() + .create(login, &channel, sent_at, sent_sequence, body) + .await?; + tx.commit().await?; + + self.events.broadcast(&event); + Ok(event) + } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 90 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(90); + + let mut tx = self.db.begin().await?; + let expired = tx.message_events().expired(&expire_at).await?; + + let mut events = Vec::with_capacity(expired.len()); + for (channel, message) in expired { + let deleted_sequence = tx.sequence().next().await?; + let event = tx + .message_events() + .delete(&channel, &message, relative_to, deleted_sequence) + .await?; + events.push(event); + } + + tx.commit().await?; + + for event in events { + self.events.broadcast(&event); + } + + Ok(()) + } + + pub async fn subscribe( + &self, + resume_at: Option<Sequence>, + ) -> Result<impl Stream<Item = ChannelEvent> + std::fmt::Debug, sqlx::Error> { + // Subscribe before retrieving, to catch messages broadcast while we're + // querying the DB. We'll prune out duplicates later. + let live_messages = self.events.subscribe(); + + let mut tx = self.db.begin().await?; + let channels = tx.channels().replay(resume_at).await?; + + let channel_events = channels + .into_iter() + .map(ChannelEvent::created) + .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); + + let message_events = tx.message_events().replay(resume_at).await?; + + let mut replay_events = channel_events + .into_iter() + .chain(message_events.into_iter()) + .collect::<Vec<_>>(); + replay_events.sort_by_key(|event| event.sequence); + let resume_live_at = replay_events.last().map(|event| event.sequence); + + let replay = stream::iter(replay_events); + + // no skip_expired or resume transforms for the replay stream, as it's + // constructed not to contain messages meeting either criterion. + // + // * skip_expired is redundant with the `tx.message_events().expired(…)` call; + // * resume is redundant with the resume_at argument to + // `tx.message_events().replay(…)`.
+ let live_messages = live_messages + // Filtering on the broadcast resume point filters out messages + // before resume_at, and filters out messages duplicated from + // the replay stream. + .filter(Self::resume(resume_live_at)); + + Ok(replay.chain(live_messages)) + } + + fn resume( + resume_at: Option<Sequence>, + ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { + move |event| future::ready(resume_at < Some(event.sequence)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum EventsError { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs new file mode 100644 index 0000000..92f631f --- /dev/null +++ b/src/event/broadcaster.rs @@ -0,0 +1,3 @@ +use crate::{broadcast, event::types}; + +pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; diff --git a/src/event/extract.rs b/src/event/extract.rs new file mode 100644 index 0000000..e3021e2 --- /dev/null +++ b/src/event/extract.rs @@ -0,0 +1,85 @@ +use std::ops::Deref; + +use axum::{ + extract::FromRequestParts, + http::{request::Parts, HeaderName, HeaderValue}, +}; +use axum_extra::typed_header::TypedHeader; +use serde::{de::DeserializeOwned, Serialize}; + +// A typed header. When used as a bare extractor, reads from the +// `Last-Event-Id` HTTP header. +pub struct LastEventId<T>(pub T); + +static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); + +impl<T> headers::Header for LastEventId<T> +where + T: Serialize + DeserializeOwned, +{ + fn name() -> &'static HeaderName { + &LAST_EVENT_ID + } + + fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error> + where + I: Iterator<Item = &'i HeaderValue>, + { + let value = values.next().ok_or_else(headers::Error::invalid)?; + let value = value.to_str().map_err(|_| headers::Error::invalid())?; + let value = serde_json::from_str(value).map_err(|_| headers::Error::invalid())?; + Ok(Self(value)) + } + + fn encode<E>(&self, values: &mut E) + where + E: Extend<HeaderValue>, + { + let Self(value) = self; + // Must panic or suppress; the trait provides no other options. + let value = serde_json::to_string(value).expect("value can be encoded as JSON"); + let value = HeaderValue::from_str(&value).expect("LastEventId is a valid header value"); + + values.extend(std::iter::once(value)); + } +} + +#[async_trait::async_trait] +impl<S, T> FromRequestParts<S> for LastEventId<T> +where + S: Send + Sync, + T: Serialize + DeserializeOwned, +{ + type Rejection = <TypedHeader<Self> as FromRequestParts<S>>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { + // This is purely for ergonomics: it allows `LastEventId` to be extracted + // without having to wrap it in `TypedHeader<>`. Callers _can_ still do that, + // but they aren't forced to.
+ let TypedHeader(requested_at) = TypedHeader::from_request_parts(parts, state).await?; + + Ok(requested_at) + } +} + +impl<T> Deref for LastEventId<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + let Self(header) = self; + header + } +} + +impl<T> From<T> for LastEventId<T> { + fn from(value: T) -> Self { + Self(value) + } +} + +impl<T> LastEventId<T> { + pub fn into_inner(self) -> T { + let Self(value) = self; + value + } +} diff --git a/src/event/mod.rs b/src/event/mod.rs new file mode 100644 index 0000000..7ad3f9c --- /dev/null +++ b/src/event/mod.rs @@ -0,0 +1,9 @@ +pub mod app; +pub mod broadcaster; +mod extract; +pub mod repo; +mod routes; +mod sequence; +pub mod types; + +pub use self::{routes::router, sequence::Sequence}; diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs new file mode 100644 index 0000000..f051fec --- /dev/null +++ b/src/event/repo/message.rs @@ -0,0 +1,188 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::{types, Sequence}, + login::{self, Login}, + message::{self, Message}, +}; + +pub trait Provider { + fn message_events(&mut self) -> Events; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn message_events(&mut self) -> Events { + Events(self) + } +} + +pub struct Events<'t>(&'t mut SqliteConnection); + +impl<'c> Events<'c> { + pub async fn create( + &mut self, + sender: &Login, + channel: &Channel, + sent_at: &DateTime, + sent_sequence: Sequence, + body: &str, + ) -> Result<types::ChannelEvent, sqlx::Error> { + let id = message::Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, channel, sender, sent_at, sent_sequence, body) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: message::Id", + sender as "sender: login::Id", + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence", + body + "#, + id, + channel.id, + sender.id, + sent_at, + sent_sequence, + body, + ) + .map(|row| types::ChannelEvent { + sequence: row.sent_sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: channel.clone(), + sender: sender.clone(), + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn delete( + &mut self, + channel: &Channel, + message: &message::Id, + deleted_at: &DateTime, + deleted_sequence: Sequence, + ) -> Result<types::ChannelEvent, sqlx::Error> { + sqlx::query_scalar!( + r#" + delete from message + where id = $1 + returning 1 as "row: i64" + "#, + message, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence: deleted_sequence, + at: *deleted_at, + data: types::MessageDeletedEvent { + channel: channel.clone(), + message: message.clone(), + } + .into(), + }) + } + + pub async fn expired( + &mut self, + expire_at: &DateTime, + ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + message.id as "message: message::Id" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where sent_at < $1 + "#, + expire_at, + ) + .map(|row| { + ( + Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, + row.message, + ) + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages)
+ } + + pub async fn replay( + &mut self, + resume_at: Option<Sequence>, + ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { + let events = sqlx::query!( + r#" + select + message.id as "id: message::Id", + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + sender.id as "sender_id: login::Id", + sender.name as sender_name, + message.sent_at as "sent_at: DateTime", + message.sent_sequence as "sent_sequence: Sequence", + message.body + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + order by sent_sequence asc + "#, + resume_at, + ) + .map(|row| types::ChannelEvent { + sequence: row.sent_sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(events) + } +} diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs new file mode 100644 index 0000000..e216a50 --- /dev/null +++ b/src/event/repo/mod.rs @@ -0,0 +1 @@ +pub mod message; diff --git a/src/event/routes.rs b/src/event/routes.rs new file mode 100644 index 0000000..77761ca --- /dev/null +++ b/src/event/routes.rs @@ -0,0 +1,93 @@ +use axum::{ + extract::State, + response::{ + sse::{self, Sse}, + IntoResponse, Response, + }, + routing::get, + Router, +}; +use axum_extra::extract::Query; +use futures::stream::{Stream, StreamExt as _}; + +use super::{extract::LastEventId, types}; +use crate::{ + app::App, + error::{Internal, Unauthorized}, + event::Sequence, + login::app::ValidateError, + token::extract::Identity, +}; + +#[cfg(test)] +mod test; + +pub fn router() -> Router<App> { + Router::new().route("/api/events", get(events)) } + +#[derive(Default, serde::Deserialize)] +struct EventsQuery { + resume_point: Option<Sequence>, +} + +async fn events( + State(app): State<App>, + identity: Identity, + last_event_id: Option<LastEventId<Sequence>>, + Query(query): Query<EventsQuery>, +) -> Result<Events<impl Stream<Item = types::ChannelEvent> + std::fmt::Debug>, EventsError> { + let resume_at = last_event_id + .map(LastEventId::into_inner) + .or(query.resume_point); + + let stream = app.events().subscribe(resume_at).await?; + let stream = app.logins().limit_stream(identity.token, stream).await?; + + Ok(Events(stream)) +} + +#[derive(Debug)] +struct Events<S>(S); + +impl<S> IntoResponse for Events<S> +where + S: Stream<Item = types::ChannelEvent> + Send + 'static, +{ + fn into_response(self) -> Response { + let Self(stream) = self; + let stream = stream.map(sse::Event::try_from); + Sse::new(stream) + .keep_alive(sse::KeepAlive::default()) + .into_response() + } +} + +impl TryFrom<types::ChannelEvent> for sse::Event { + type Error = serde_json::Error; + + fn try_from(event: types::ChannelEvent) -> Result<Self, Self::Error> { + let id = serde_json::to_string(&event.sequence)?; + let data = serde_json::to_string_pretty(&event)?; + + let event = Self::default().id(id).data(data); + + Ok(event) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum EventsError { + DatabaseError(#[from] sqlx::Error), + ValidateError(#[from] ValidateError), +} + +impl IntoResponse for EventsError { + fn into_response(self) -> Response { + match self { + Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(), + other
=> Internal::from(other).into_response(), + } + } +} diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs new file mode 100644 index 0000000..9a3b12a --- /dev/null +++ b/src/event/routes/test.rs @@ -0,0 +1,439 @@ +use axum::extract::State; +use axum_extra::extract::Query; +use futures::{ + future, + stream::{self, StreamExt as _}, +}; + +use crate::{ + event::routes, + test::fixtures::{self, future::Immediately as _}, +}; + +#[tokio::test] +async fn includes_historical_message() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered stored message"); + + assert_eq!(message, event); +} + +#[tokio::test] +async fn includes_live_message() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the semantics + + let sender = fixtures::login::create(&app).await; + let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered live message"); + + assert_eq!(message, event); +} + +#[tokio::test] +async fn includes_multiple_channels() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + + let channels = [ + fixtures::channel::create(&app, &fixtures::now()).await, + fixtures::channel::create(&app, &fixtures::now()).await, + ]; + + let messages = stream::iter(channels) + .then(|channel| { + let app = app.clone(); + let sender = sender.clone(); + let channel = channel.clone(); + async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } + }) + .collect::<Vec<_>>() + .await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response.
+ + let events = events + .filter(fixtures::filter::messages()) + .take(messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &messages { + assert!(events.iter().any(|event| { event == message })); + } +} + +#[tokio::test] +async fn sequential_messages() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + let messages = vec![ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let mut events = events.filter(|event| future::ready(messages.contains(event))); + + // Verify delivery in order + for message in &messages { + let event = events + .next() + .immediately() + .await + .expect("undelivered messages remaining"); + + assert_eq!(message, &event); + } +} + +#[tokio::test] +async fn resumes_from() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + let later_messages = vec![ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + + let resume_at = { + // First subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered events"); + + assert_eq!(initial_message, event); + + event.sequence + }; + + // Resume after disconnect + let routes::Events(resumed) = routes::events( + State(app), + subscriber, + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let events = resumed + .take(later_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &later_messages { + assert!(events.iter().any(|event| event == message)); + } +} + +// This test verifies a real bug I hit developing the vector-of-sequences +// approach to resuming events. A small omission caused the event IDs in a +// resumed stream to _omit_ channels that were in the original stream until +// those channels also appeared in the resumed stream.
+// +// Clients would see something like +// * In the original stream, Cfoo=5,Cbar=8 +// * In the resumed stream, Cfoo=6 (no Cbar sequence number) +// +// Disconnecting and reconnecting a second time, using event IDs from that +// initial period of the first resume attempt, would then cause the second +// resume attempt to restart all other channels from the beginning, and not +// from where the first disconnection happened. +// +// This is a real and valid behaviour for clients! +#[tokio::test] +async fn serial_resume() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + let channel_a = fixtures::channel::create(&app, &fixtures::now()).await; + let channel_b = fixtures::channel::create(&app, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + + let resume_at = { + let initial_messages = [ + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + ]; + + // First subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let events = events + .filter(fixtures::filter::messages()) + .take(initial_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &initial_messages { + assert!(events.iter().any(|event| event == message)); + } + + let event = events.last().expect("this vec is non-empty"); + + event.sequence + }; + + // Resume after disconnect + let resume_at = { + let resume_messages = [ + // Note that channel_b does not appear here. The buggy behaviour + // would be masked if channel_b happened to send a new message + // into the resumed event stream. + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + ]; + + // Second subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let events = events + .filter(fixtures::filter::messages()) + .take(resume_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &resume_messages { + assert!(events.iter().any(|event| event == message)); + } + + let event = events.last().expect("this vec is non-empty"); + + event.sequence + }; + + // Resume after disconnect a second time + { + // At this point, we can send on either channel and demonstrate the + // problem. The resume point should be before both of these messages, but + // after _all_ prior messages.
+ let final_messages = [ + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + ]; + + // Third subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let events = events + .filter(fixtures::filter::messages()) + .take(final_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + // This set of messages, in particular, _should not_ include any prior + // messages from `initial_messages` or `resume_messages`. + for message in &final_messages { + assert!(events.iter().any(|event| event == message)); + } + }; +} + +#[tokio::test] +async fn terminates_on_token_expiry() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + // Subscribe via the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = + fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; + + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + app.logins() + .expire(&fixtures::now()) + .await + .expect("expiring tokens succeeds"); + + // These should not be delivered. + let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|event| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} + +#[tokio::test] +async fn terminates_on_logout() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + // Subscribe via the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber_token = + fixtures::identity::logged_in(&app, &subscriber_creds, &fixtures::now()).await; + let subscriber = + fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; + + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + app.logins() + .logout(&subscriber.token) + .await + .expect("logging out succeeds"); + + // These should not be delivered.
+ let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|event| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} diff --git a/src/event/sequence.rs b/src/event/sequence.rs new file mode 100644 index 0000000..9ebddd7 --- /dev/null +++ b/src/event/sequence.rs @@ -0,0 +1,24 @@ +use std::fmt; + +#[derive( + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + serde::Deserialize, + serde::Serialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +impl fmt::Display for Sequence { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value) = self; + value.fmt(f) + } +} diff --git a/src/event/types.rs b/src/event/types.rs new file mode 100644 index 0000000..cd7dea6 --- /dev/null +++ b/src/event/types.rs @@ -0,0 +1,97 @@ +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::Sequence, + login::Login, + message::{self, Message}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct ChannelEvent { + #[serde(skip)] + pub sequence: Sequence, + pub at: DateTime, + #[serde(flatten)] + pub data: ChannelEventData, +} + +impl ChannelEvent { + pub fn created(channel: Channel) -> Self { + Self { + at: channel.created_at, + sequence: channel.created_sequence, + data: CreatedEvent { channel }.into(), + } + } + + pub fn channel_id(&self) -> &channel::Id { + match &self.data { + ChannelEventData::Created(event) => &event.channel.id, + ChannelEventData::Message(event) => &event.channel.id, + ChannelEventData::MessageDeleted(event) => &event.channel.id, + ChannelEventData::Deleted(event) => &event.channel, + } + } +} + +impl<'c> From<&'c ChannelEvent> for Sequence { + fn from(event: &'c ChannelEvent) -> Self { + event.sequence + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChannelEventData { + Created(CreatedEvent), + Message(MessageEvent), + MessageDeleted(MessageDeletedEvent), + Deleted(DeletedEvent), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct CreatedEvent { + pub channel: Channel, +} + +impl From<CreatedEvent> for ChannelEventData { + fn from(event: CreatedEvent) -> Self { + Self::Created(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageEvent { + pub channel: Channel, + pub sender: Login, + pub message: Message, +} + +impl From<MessageEvent> for ChannelEventData { + fn from(event: MessageEvent) -> Self { + Self::Message(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageDeletedEvent { + pub channel: Channel, + pub message: message::Id, +} + +impl From<MessageDeletedEvent> for ChannelEventData { + fn from(event: MessageDeletedEvent) -> Self { + Self::MessageDeleted(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct DeletedEvent { + pub channel: channel::Id, +} + +impl From<DeletedEvent> for ChannelEventData { + fn from(event: DeletedEvent) -> Self { + Self::Deleted(event) + } +} diff --git a/src/events/app.rs b/src/events/app.rs deleted file mode 100644 index 1fa2f70..0000000 --- a/src/events/app.rs +++ /dev/null @@ -1,140 +0,0 @@ -use chrono::TimeDelta; -use futures::{ - future, - stream::{self, StreamExt as _}, - Stream, -}; -use sqlx::sqlite::SqlitePool; - -use
super::{ - broadcaster::Broadcaster, - repo::message::Provider as _, - types::{self, ChannelEvent}, -}; -use crate::{ - channel, - clock::DateTime, - repo::{ - channel::Provider as _, - error::NotFound as _, - login::Login, - sequence::{Provider as _, Sequence}, - }, -}; - -pub struct Events<'a> { - db: &'a SqlitePool, - events: &'a Broadcaster, -} - -impl<'a> Events<'a> { - pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { - Self { db, events } - } - - pub async fn send( - &self, - login: &Login, - channel: &channel::Id, - body: &str, - sent_at: &DateTime, - ) -> Result<ChannelEvent, EventsError> { - let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let sent_sequence = tx.sequence().next().await?; - let event = tx - .message_events() - .create(login, &channel, sent_at, sent_sequence, body) - .await?; - tx.commit().await?; - - self.events.broadcast(&event); - Ok(event) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 90 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(90); - - let mut tx = self.db.begin().await?; - let expired = tx.message_events().expired(&expire_at).await?; - - let mut events = Vec::with_capacity(expired.len()); - for (channel, message) in expired { - let deleted_sequence = tx.sequence().next().await?; - let event = tx - .message_events() - .delete(&channel, &message, relative_to, deleted_sequence) - .await?; - events.push(event); - } - - tx.commit().await?; - - for event in events { - self.events.broadcast(&event); - } - - Ok(()) - } - - pub async fn subscribe( - &self, - resume_at: Option<Sequence>, - ) -> Result<impl Stream<Item = ChannelEvent> + std::fmt::Debug, sqlx::Error> { - // Subscribe before retrieving, to catch messages broadcast while we're - // querying the DB. We'll prune out duplicates later. - let live_messages = self.events.subscribe(); - - let mut tx = self.db.begin().await?; - let channels = tx.channels().replay(resume_at).await?; - - let channel_events = channels - .into_iter() - .map(ChannelEvent::created) - .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); - - let message_events = tx.message_events().replay(resume_at).await?; - - let mut replay_events = channel_events - .into_iter() - .chain(message_events.into_iter()) - .collect::<Vec<_>>(); - replay_events.sort_by_key(|event| event.sequence); - let resume_live_at = replay_events.last().map(|event| event.sequence); - - let replay = stream::iter(replay_events); - - // no skip_expired or resume transforms for stored_messages, as it's - // constructed not to contain messages meeting either criterion. - // - // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; - // * resume is redundant with the resume_at argument to - // `tx.broadcasts().replay(…)`. - let live_messages = live_messages - // Filtering on the broadcast resume point filters out messages - // before resume_at, and filters out messages duplicated from - // stored_messages.
- .filter(Self::resume(resume_live_at)); - - Ok(replay.chain(live_messages)) - } - - fn resume( - resume_at: Option<Sequence>, - ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { - move |event| future::ready(resume_at < Some(event.sequence)) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum EventsError { - #[error("channel {0} not found")] - ChannelNotFound(channel::Id), - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs deleted file mode 100644 index 6b664cb..0000000 --- a/src/events/broadcaster.rs +++ /dev/null @@ -1,3 +0,0 @@ -use crate::{broadcast, events::types}; - -pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; diff --git a/src/events/extract.rs b/src/events/extract.rs deleted file mode 100644 index e3021e2..0000000 --- a/src/events/extract.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::ops::Deref; - -use axum::{ - extract::FromRequestParts, - http::{request::Parts, HeaderName, HeaderValue}, -}; -use axum_extra::typed_header::TypedHeader; -use serde::{de::DeserializeOwned, Serialize}; - -// A typed header. When used as a bare extractor, reads from the -// `Last-Event-Id` HTTP header. -pub struct LastEventId<T>(pub T); - -static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); - -impl<T> headers::Header for LastEventId<T> -where - T: Serialize + DeserializeOwned, -{ - fn name() -> &'static HeaderName { - &LAST_EVENT_ID - } - - fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error> - where - I: Iterator<Item = &'i HeaderValue>, - { - let value = values.next().ok_or_else(headers::Error::invalid)?; - let value = value.to_str().map_err(|_| headers::Error::invalid())?; - let value = serde_json::from_str(value).map_err(|_| headers::Error::invalid())?; - Ok(Self(value)) - } - - fn encode<E>(&self, values: &mut E) - where - E: Extend<HeaderValue>, - { - let Self(value) = self; - // Must panic or suppress; the trait provides no other options. - let value = serde_json::to_string(value).expect("value can be encoded as JSON"); - let value = HeaderValue::from_str(&value).expect("LastEventId is a valid header value"); - - values.extend(std::iter::once(value)); - } -} - -#[async_trait::async_trait] -impl<S, T> FromRequestParts<S> for LastEventId<T> -where - S: Send + Sync, - T: Serialize + DeserializeOwned, -{ - type Rejection = <TypedHeader<Self> as FromRequestParts<S>>::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { - // This is purely for ergonomics: it allows `LastEventId` to be extracted - // without having to wrap it in `TypedHeader<>`. Callers _can_ still do that, - // but they aren't forced to.
- let TypedHeader(requested_at) = TypedHeader::from_request_parts(parts, state).await?; - - Ok(requested_at) - } -} - -impl<T> Deref for LastEventId<T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - let Self(header) = self; - header - } -} - -impl<T> From<T> for LastEventId<T> { - fn from(value: T) -> Self { - Self(value) - } -} - -impl<T> LastEventId<T> { - pub fn into_inner(self) -> T { - let Self(value) = self; - value - } -} diff --git a/src/events/mod.rs b/src/events/mod.rs deleted file mode 100644 index 711ae64..0000000 --- a/src/events/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod app; -pub mod broadcaster; -mod extract; -pub mod repo; -mod routes; -pub mod types; - -pub use self::routes::router; diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs deleted file mode 100644 index 00c24b1..0000000 --- a/src/events/repo/message.rs +++ /dev/null @@ -1,188 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - channel, - clock::DateTime, - events::types, - login, message, - repo::{channel::Channel, login::Login, message::Message, sequence::Sequence}, -}; - -pub trait Provider { - fn message_events(&mut self) -> Events; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn message_events(&mut self) -> Events { - Events(self) - } -} - -pub struct Events<'t>(&'t mut SqliteConnection); - -impl<'c> Events<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - sent_at: &DateTime, - sent_sequence: Sequence, - body: &str, - ) -> Result<types::ChannelEvent, sqlx::Error> { - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, channel, sender, sent_at, sent_sequence, body) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: message::Id", - sender as "sender: login::Id", - sent_at as "sent_at: DateTime", - sent_sequence as "sent_sequence: Sequence", - body - "#, - id, - channel.id, - sender.id, - sent_at, - sent_sequence, - body, - ) - .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, - data: types::MessageEvent { - channel: channel.clone(), - sender: sender.clone(), - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - pub async fn delete( - &mut self, - channel: &Channel, - message: &message::Id, - deleted_at: &DateTime, - deleted_sequence: Sequence, - ) -> Result<types::ChannelEvent, sqlx::Error> { - sqlx::query_scalar!( - r#" - delete from message - where id = $1 - returning 1 as "row: i64" - "#, - message, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, - data: types::MessageDeletedEvent { - channel: channel.clone(), - message: message.clone(), - } - .into(), - }) - } - - pub async fn expired( - &mut self, - expire_at: &DateTime, - ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - message.id as "message: message::Id" - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where sent_at < $1 - "#, - expire_at, - ) - .map(|row| { - ( - Channel { - id: row.channel_id, - name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, - }, - row.message, - ) - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - }
- - pub async fn replay( - &mut self, - resume_at: Option<Sequence>, - ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { - let events = sqlx::query!( - r#" - select - message.id as "id: message::Id", - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - sender.id as "sender_id: login::Id", - sender.name as sender_name, - message.sent_at as "sent_at: DateTime", - message.sent_sequence as "sent_sequence: Sequence", - message.body - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where coalesce(message.sent_sequence > $1, true) - order by sent_sequence asc - "#, - resume_at, - ) - .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, - data: types::MessageEvent { - channel: Channel { - id: row.channel_id, - name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, - }, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(events) - } -} diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs deleted file mode 100644 index e216a50..0000000 --- a/src/events/repo/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod message; diff --git a/src/events/routes.rs b/src/events/routes.rs deleted file mode 100644 index d81c7fb..0000000 --- a/src/events/routes.rs +++ /dev/null @@ -1,92 +0,0 @@ -use axum::{ - extract::State, - response::{ - sse::{self, Sse}, - IntoResponse, Response, - }, - routing::get, - Router, -}; -use axum_extra::extract::Query; -use futures::stream::{Stream, StreamExt as _}; - -use super::{extract::LastEventId, types}; -use crate::{ - app::App, - error::{Internal, Unauthorized}, - login::{app::ValidateError, extract::Identity}, - repo::sequence::Sequence, -}; - -#[cfg(test)] -mod test; - -pub fn router() -> Router<App> { - Router::new().route("/api/events", get(events)) -} - -#[derive(Default, serde::Deserialize)] -struct EventsQuery { - resume_point: Option<Sequence>, -} - -async fn events( - State(app): State<App>, - identity: Identity, - last_event_id: Option<LastEventId<Sequence>>, - Query(query): Query<EventsQuery>, -) -> Result<Events<impl Stream<Item = types::ChannelEvent> + std::fmt::Debug>, EventsError> { - let resume_at = last_event_id - .map(LastEventId::into_inner) - .or(query.resume_point); - - let stream = app.events().subscribe(resume_at).await?; - let stream = app.logins().limit_stream(identity.token, stream).await?; - - Ok(Events(stream)) -} - -#[derive(Debug)] -struct Events<S>(S); - -impl<S> IntoResponse for Events<S> -where - S: Stream<Item = types::ChannelEvent> + Send + 'static, -{ - fn into_response(self) -> Response { - let Self(stream) = self; - let stream = stream.map(sse::Event::try_from); - Sse::new(stream) - .keep_alive(sse::KeepAlive::default()) - .into_response() - } -} - -impl TryFrom<types::ChannelEvent> for sse::Event { - type Error = serde_json::Error; - - fn try_from(event: types::ChannelEvent) -> Result<Self, Self::Error> { - let id = serde_json::to_string(&event.sequence)?; - let data = serde_json::to_string_pretty(&event)?; - - let event = Self::default().id(id).data(data); - - Ok(event) - } -} - -#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub enum EventsError { - DatabaseError(#[from] sqlx::Error), - ValidateError(#[from] ValidateError), -} - -impl IntoResponse for EventsError { - fn into_response(self) -> Response { - match self { - Self::ValidateError(ValidateError::InvalidToken) =>
Unauthorized.into_response(), - other => Internal::from(other).into_response(), - } - } -} diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs deleted file mode 100644 index 11f01b8..0000000 --- a/src/events/routes/test.rs +++ /dev/null @@ -1,439 +0,0 @@ -use axum::extract::State; -use axum_extra::extract::Query; -use futures::{ - future, - stream::{self, StreamExt as _}, -}; - -use crate::{ - events::routes, - test::fixtures::{self, future::Immediately as _}, -}; - -#[tokio::test] -async fn includes_historical_message() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. - - let event = events - .filter(fixtures::filter::messages()) - .next() - .immediately() - .await - .expect("delivered stored message"); - - assert_eq!(message, event); -} - -#[tokio::test] -async fn includes_live_message() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = - routes::events(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the semantics - - let sender = fixtures::login::create(&app).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - let event = events - .filter(fixtures::filter::messages()) - .next() - .immediately() - .await - .expect("delivered live message"); - - assert_eq!(message, event); -} - -#[tokio::test] -async fn includes_multiple_channels() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - - let channels = [ - fixtures::channel::create(&app, &fixtures::now()).await, - fixtures::channel::create(&app, &fixtures::now()).await, - ]; - - let messages = stream::iter(channels) - .then(|channel| { - let app = app.clone(); - let sender = sender.clone(); - let channel = channel.clone(); - async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } - }) - .collect::<Vec<_>>() - .await; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response.
- - let events = events - .filter(fixtures::filter::messages()) - .take(messages.len()) - .collect::<Vec<_>>() - .immediately() - .await; - - for message in &messages { - assert!(events.iter().any(|event| { event == message })); - } -} - -#[tokio::test] -async fn sequential_messages() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - let messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. - - let mut events = events.filter(|event| future::ready(messages.contains(event))); - - // Verify delivery in order - for message in &messages { - let event = events - .next() - .immediately() - .await - .expect("undelivered messages remaining"); - - assert_eq!(message, &event); - } -} - -#[tokio::test] -async fn resumes_from() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - let later_messages = vec![ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - // Call the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; - - let resume_at = { - // First subscription - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - None, - Query::default(), - ) - .await - .expect("subscribe never fails"); - - let event = events - .filter(fixtures::filter::messages()) - .next() - .immediately() - .await - .expect("delivered events"); - - assert_eq!(initial_message, event); - - event.sequence - }; - - // Resume after disconnect - let routes::Events(resumed) = routes::events( - State(app), - subscriber, - Some(resume_at.into()), - Query::default(), - ) - .await - .expect("subscribe never fails"); - - // Verify the structure of the response. - - let events = resumed - .take(later_messages.len()) - .collect::<Vec<_>>() - .immediately() - .await; - - for message in &later_messages { - assert!(events.iter().any(|event| event == message)); - } -} - -// This test verifies a real bug I hit developing the vector-of-sequences -// approach to resuming events. A small omission caused the event IDs in a -// resumed stream to _omit_ channels that were in the original stream until -// those channels also appeared in the resumed stream.
-//
-// Clients would see something like
-// * In the original stream, Cfoo=5,Cbar=8
-// * In the resumed stream, Cfoo=6 (no Cbar sequence number)
-//
-// Disconnecting and reconnecting a second time, using event IDs from that
-// initial period of the first resume attempt, would then cause the second
-// resume attempt to restart all other channels from the beginning, and not
-// from where the first disconnection happened.
-//
-// This is a real and valid behaviour for clients!
-#[tokio::test]
-async fn serial_resume() {
-    // Set up the environment
-
-    let app = fixtures::scratch_app().await;
-    let sender = fixtures::login::create(&app).await;
-    let channel_a = fixtures::channel::create(&app, &fixtures::now()).await;
-    let channel_b = fixtures::channel::create(&app, &fixtures::now()).await;
-
-    // Call the endpoint
-
-    let subscriber_creds = fixtures::login::create_with_password(&app).await;
-    let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await;
-
-    let resume_at = {
-        let initial_messages = [
-            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-            fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
-        ];
-
-        // First subscription
-        let routes::Events(events) = routes::events(
-            State(app.clone()),
-            subscriber.clone(),
-            None,
-            Query::default(),
-        )
-        .await
-        .expect("subscribe never fails");
-
-        let events = events
-            .filter(fixtures::filter::messages())
-            .take(initial_messages.len())
-            .collect::<Vec<_>>()
-            .immediately()
-            .await;
-
-        for message in &initial_messages {
-            assert!(events.iter().any(|event| event == message));
-        }
-
-        let event = events.last().expect("this vec is non-empty");
-
-        event.sequence
-    };
-
-    // Resume after disconnect
-    let resume_at = {
-        let resume_messages = [
-            // Note that channel_b does not appear here. The buggy behaviour
-            // would be masked if channel_b happened to send a new message
-            // into the resumed event stream.
-            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-        ];
-
-        // Second subscription
-        let routes::Events(events) = routes::events(
-            State(app.clone()),
-            subscriber.clone(),
-            Some(resume_at.into()),
-            Query::default(),
-        )
-        .await
-        .expect("subscribe never fails");
-
-        let events = events
-            .filter(fixtures::filter::messages())
-            .take(resume_messages.len())
-            .collect::<Vec<_>>()
-            .immediately()
-            .await;
-
-        for message in &resume_messages {
-            assert!(events.iter().any(|event| event == message));
-        }
-
-        let event = events.last().expect("this vec is non-empty");
-
-        event.sequence
-    };
-
-    // Resume after disconnect a second time
-    {
-        // At this point, we can send on either channel and demonstrate the
-        // problem. The resume point should be before both of these messages,
-        // but after _all_ prior messages.
- let final_messages = [ - fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, - ]; - - // Third subscription - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - Some(resume_at.into()), - Query::default(), - ) - .await - .expect("subscribe never fails"); - - let events = events - .filter(fixtures::filter::messages()) - .take(final_messages.len()) - .collect::>() - .immediately() - .await; - - // This set of messages, in particular, _should not_ include any prior - // messages from `initial_messages` or `resume_messages`. - for message in &final_messages { - assert!(events.iter().any(|event| event == message)); - } - }; -} - -#[tokio::test] -async fn terminates_on_token_expiry() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - // Subscribe via the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber = - fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; - - let routes::Events(events) = - routes::events(State(app.clone()), subscriber, None, Query::default()) - .await - .expect("subscribe never fails"); - - // Verify the resulting stream's behaviour - - app.logins() - .expire(&fixtures::now()) - .await - .expect("expiring tokens succeeds"); - - // These should not be delivered. - let messages = [ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - assert!(events - .filter(|event| future::ready(messages.contains(event))) - .next() - .immediately() - .await - .is_none()); -} - -#[tokio::test] -async fn terminates_on_logout() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app, &fixtures::now()).await; - let sender = fixtures::login::create(&app).await; - - // Subscribe via the endpoint - - let subscriber_creds = fixtures::login::create_with_password(&app).await; - let subscriber_token = - fixtures::identity::logged_in(&app, &subscriber_creds, &fixtures::now()).await; - let subscriber = - fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; - - let routes::Events(events) = routes::events( - State(app.clone()), - subscriber.clone(), - None, - Query::default(), - ) - .await - .expect("subscribe never fails"); - - // Verify the resulting stream's behaviour - - app.logins() - .logout(&subscriber.token) - .await - .expect("expiring tokens succeeds"); - - // These should not be delivered. 
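This test and the expiry test above both hinge on the stream itself closing once the token is revoked: `next()` on a stream that has ended resolves to `None`, so the `is_none()` assertion below distinguishes a properly terminated stream from one that is merely quiet (which would presumably only surface as a failure once the fixtures' timeout fires).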
- let messages = [ - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, - ]; - - assert!(events - .filter(|event| future::ready(messages.contains(event))) - .next() - .immediately() - .await - .is_none()); -} diff --git a/src/events/types.rs b/src/events/types.rs deleted file mode 100644 index 762b6e5..0000000 --- a/src/events/types.rs +++ /dev/null @@ -1,96 +0,0 @@ -use crate::{ - channel, - clock::DateTime, - message, - repo::{channel::Channel, login::Login, message::Message, sequence::Sequence}, -}; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct ChannelEvent { - #[serde(skip)] - pub sequence: Sequence, - pub at: DateTime, - #[serde(flatten)] - pub data: ChannelEventData, -} - -impl ChannelEvent { - pub fn created(channel: Channel) -> Self { - Self { - at: channel.created_at, - sequence: channel.created_sequence, - data: CreatedEvent { channel }.into(), - } - } - - pub fn channel_id(&self) -> &channel::Id { - match &self.data { - ChannelEventData::Created(event) => &event.channel.id, - ChannelEventData::Message(event) => &event.channel.id, - ChannelEventData::MessageDeleted(event) => &event.channel.id, - ChannelEventData::Deleted(event) => &event.channel, - } - } -} - -impl<'c> From<&'c ChannelEvent> for Sequence { - fn from(event: &'c ChannelEvent) -> Self { - event.sequence - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ChannelEventData { - Created(CreatedEvent), - Message(MessageEvent), - MessageDeleted(MessageDeletedEvent), - Deleted(DeletedEvent), -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct CreatedEvent { - pub channel: Channel, -} - -impl From for ChannelEventData { - fn from(event: CreatedEvent) -> Self { - Self::Created(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageEvent { - pub channel: Channel, - pub sender: Login, - pub message: Message, -} - -impl From for ChannelEventData { - fn from(event: MessageEvent) -> Self { - Self::Message(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct MessageDeletedEvent { - pub channel: Channel, - pub message: message::Id, -} - -impl From for ChannelEventData { - fn from(event: MessageDeletedEvent) -> Self { - Self::MessageDeleted(event) - } -} - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct DeletedEvent { - pub channel: channel::Id, -} - -impl From for ChannelEventData { - fn from(event: DeletedEvent) -> Self { - Self::Deleted(event) - } -} diff --git a/src/lib.rs b/src/lib.rs index 2300071..bbcb314 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,12 +8,12 @@ mod channel; pub mod cli; mod clock; mod error; -mod events; +mod event; mod expire; mod id; mod login; mod message; -mod password; mod repo; #[cfg(test)] mod test; +mod token; diff --git a/src/login/app.rs b/src/login/app.rs index 8ea0a91..60475af 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -6,18 +6,15 @@ use futures::{ }; use sqlx::sqlite::SqlitePool; -use super::{ - broadcaster::Broadcaster, extract::IdentitySecret, repo::auth::Provider as _, token, types, -}; +use super::{broadcaster::Broadcaster, repo::auth::Provider as _, types, Login}; use crate::{ clock::DateTime, - password::Password, + event::Sequence, + login::Password, repo::{ - error::NotFound as _, - 
login::{Login, Provider as _}, - sequence::{Provider as _, Sequence}, - token::Provider as _, + error::NotFound as _, login::Provider as _, sequence::Provider as _, token::Provider as _, }, + token::{self, Secret}, }; pub struct Logins<'a> { @@ -43,7 +40,7 @@ impl<'a> Logins<'a> { name: &str, password: &Password, login_at: &DateTime, - ) -> Result { + ) -> Result { let mut tx = self.db.begin().await?; let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? { @@ -78,7 +75,7 @@ impl<'a> Logins<'a> { pub async fn validate( &self, - secret: &IdentitySecret, + secret: &Secret, used_at: &DateTime, ) -> Result<(token::Id, Login), ValidateError> { let mut tx = self.db.begin().await?; diff --git a/src/login/extract.rs b/src/login/extract.rs index 39dd9e4..c2d97f2 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -1,182 +1,15 @@ -use std::fmt; +use axum::{extract::FromRequestParts, http::request::Parts}; -use axum::{ - extract::{FromRequestParts, State}, - http::request::Parts, - response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, -}; -use axum_extra::extract::cookie::{Cookie, CookieJar}; - -use crate::{ - app::App, - clock::RequestedAt, - error::{Internal, Unauthorized}, - login::{app::ValidateError, token}, - repo::login::Login, -}; - -// The usage pattern here - receive the extractor as an argument, return it in -// the response - is heavily modelled after CookieJar's own intended usage. -#[derive(Clone)] -pub struct IdentityToken { - cookies: CookieJar, -} - -impl fmt::Debug for IdentityToken { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("IdentityToken") - .field( - "identity", - &self.cookies.get(IDENTITY_COOKIE).map(|_| "********"), - ) - .finish() - } -} - -impl IdentityToken { - // Creates a new, unpopulated identity token store. - #[cfg(test)] - pub fn new() -> Self { - Self { - cookies: CookieJar::new(), - } - } - - // Get the identity secret sent in the request, if any. If the identity - // was not sent, or if it has previously been [clear]ed, then this will - // return [None]. If the identity has previously been [set], then this - // will return that secret, regardless of what the request originally - // included. - pub fn secret(&self) -> Option { - self.cookies - .get(IDENTITY_COOKIE) - .map(Cookie::value) - .map(IdentitySecret::from) - } - - // Positively set the identity secret, and ensure that it will be sent - // back to the client when this extractor is included in a response. - pub fn set(self, secret: impl Into) -> Self { - let IdentitySecret(secret) = secret.into(); - let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret)) - .http_only(true) - .path("/api/") - .permanent() - .build(); - - Self { - cookies: self.cookies.add(identity_cookie), - } - } - - // Remove the identity secret and ensure that it will be cleared when this - // extractor is included in a response. 
- pub fn clear(self) -> Self { - Self { - cookies: self.cookies.remove(IDENTITY_COOKIE), - } - } -} - -const IDENTITY_COOKIE: &str = "identity"; - -#[async_trait::async_trait] -impl FromRequestParts for IdentityToken -where - S: Send + Sync, -{ - type Rejection = >::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let cookies = CookieJar::from_request_parts(parts, state).await?; - Ok(Self { cookies }) - } -} - -impl IntoResponseParts for IdentityToken { - type Error = ::Error; - - fn into_response_parts(self, res: ResponseParts) -> Result { - let Self { cookies } = self; - cookies.into_response_parts(res) - } -} - -#[derive(sqlx::Type)] -#[sqlx(transparent)] -pub struct IdentitySecret(String); - -impl fmt::Debug for IdentitySecret { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("IdentityToken").field(&"********").finish() - } -} - -impl From for IdentitySecret -where - S: Into, -{ - fn from(value: S) -> Self { - Self(value.into()) - } -} - -#[derive(Clone, Debug)] -pub struct Identity { - pub token: token::Id, - pub login: Login, -} +use super::Login; +use crate::{app::App, token::extract::Identity}; #[async_trait::async_trait] -impl FromRequestParts for Identity { - type Rejection = LoginError; +impl FromRequestParts for Login { + type Rejection = >::Rejection; async fn from_request_parts(parts: &mut Parts, state: &App) -> Result { - // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on - // stable), the following can be replaced: - // - // ``` - // let Ok(identity_token) = IdentityToken::from_request_parts( - // parts, - // state, - // ).await; - // ``` - let identity_token = IdentityToken::from_request_parts(parts, state).await?; - let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; - - let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; - - let app = State::::from_request_parts(parts, state).await?; - match app.logins().validate(&secret, &used_at).await { - Ok((token, login)) => Ok(Identity { token, login }), - Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), - Err(other) => Err(other.into()), - } - } -} - -pub enum LoginError { - Failure(E), - Unauthorized, -} - -impl IntoResponse for LoginError -where - E: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Self::Unauthorized => Unauthorized.into_response(), - Self::Failure(e) => e.into_response(), - } - } -} + let identity = Identity::from_request_parts(parts, state).await?; -impl From for LoginError -where - E: Into, -{ - fn from(err: E) -> Self { - Self::Failure(err.into()) + Ok(identity.login) } } diff --git a/src/login/mod.rs b/src/login/mod.rs index 0430f4b..91c1821 100644 --- a/src/login/mod.rs +++ b/src/login/mod.rs @@ -2,10 +2,22 @@ pub mod app; pub mod broadcaster; pub mod extract; mod id; +pub mod password; mod repo; mod routes; -pub mod token; pub mod types; -pub use self::id::Id; -pub use self::routes::router; +pub use self::{id::Id, password::Password, routes::router}; + +// This also implements FromRequestParts (see `./extract.rs`). As a result, it +// can be used as an extractor for endpoints that want to require login, or for +// endpoints that need to behave differently depending on whether the client is +// or is not logged in. 
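As a usage sketch (a hypothetical handler, not part of this patch), requiring authentication then reduces to naming the type in an axum handler's signature; the struct itself follows below:

```
use axum::Json;

// Hypothetical endpoint: because `Login` implements FromRequestParts, the
// request is rejected before this body ever runs unless the client presents
// a valid identity cookie.
async fn whoami(login: Login) -> Json<Login> {
    Json(login)
}
```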
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Login {
+    pub id: Id,
+    pub name: String,
+    // The omission of the hashed password is deliberate, to minimize the
+    // chance that it ends up tangled up in debug output or in some other chunk
+    // of logic elsewhere.
+}
diff --git a/src/login/password.rs b/src/login/password.rs
new file mode 100644
index 0000000..da3930f
--- /dev/null
+++ b/src/login/password.rs
@@ -0,0 +1,58 @@
+use std::fmt;
+
+use argon2::Argon2;
+use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
+use rand_core::OsRng;
+
+#[derive(Debug, sqlx::Type)]
+#[sqlx(transparent)]
+pub struct StoredHash(String);
+
+impl StoredHash {
+    pub fn verify(&self, password: &Password) -> Result<bool, password_hash::errors::Error> {
+        let hash = PasswordHash::new(&self.0)?;
+
+        match Argon2::default().verify_password(password.as_bytes(), &hash) {
+            // Successful authentication, not an error
+            Ok(()) => Ok(true),
+            // Unsuccessful authentication, also not an error
+            Err(password_hash::errors::Error::Password) => Ok(false),
+            // Password validation failed for some other reason, treat as an error
+            Err(err) => Err(err),
+        }
+    }
+}
+
+#[derive(serde::Deserialize)]
+#[serde(transparent)]
+pub struct Password(String);
+
+impl Password {
+    pub fn hash(&self) -> Result<StoredHash, password_hash::errors::Error> {
+        let Self(password) = self;
+        let salt = SaltString::generate(&mut OsRng);
+        let argon2 = Argon2::default();
+        let hash = argon2
+            .hash_password(password.as_bytes(), &salt)?
+            .to_string();
+        Ok(StoredHash(hash))
+    }
+
+    fn as_bytes(&self) -> &[u8] {
+        let Self(value) = self;
+        value.as_bytes()
+    }
+}
+
+impl fmt::Debug for Password {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("Password").field(&"********").finish()
+    }
+}
+
+#[cfg(test)]
+impl From<String> for Password {
+    fn from(password: String) -> Self {
+        Self(password)
+    }
+}
diff --git a/src/login/repo/auth.rs b/src/login/repo/auth.rs
index 9816c5c..b299697 100644
--- a/src/login/repo/auth.rs
+++ b/src/login/repo/auth.rs
@@ -1,6 +1,6 @@
 use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
 
-use crate::{login, password::StoredHash, repo::login::Login};
+use crate::login::{self, password::StoredHash, Login};
 
 pub trait Provider {
     fn auth(&mut self) -> Auth;
diff --git a/src/login/routes.rs b/src/login/routes.rs
index ef75871..b571bd5 100644
--- a/src/login/routes.rs
+++ b/src/login/routes.rs
@@ -10,11 +10,11 @@ use crate::{
     app::App,
     clock::RequestedAt,
     error::{Internal, Unauthorized},
-    password::Password,
-    repo::login::Login,
+    login::{Login, Password},
 };
 
-use super::{app, extract::IdentityToken};
+use super::app;
+use crate::token::extract::IdentityToken;
 
 #[cfg(test)]
 mod test;
diff --git a/src/login/token/id.rs b/src/login/token/id.rs
deleted file mode 100644
index 9ef063c..0000000
--- a/src/login/token/id.rs
+++ /dev/null
@@ -1,27 +0,0 @@
-use std::fmt;
-
-use crate::id::Id as BaseId;
-
-// Stable identifier for a token. Prefixed with `T`.
-#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] -#[sqlx(transparent)] -#[serde(transparent)] -pub struct Id(BaseId); - -impl From for Id { - fn from(id: BaseId) -> Self { - Self(id) - } -} - -impl Id { - pub fn generate() -> Self { - BaseId::generate("T") - } -} - -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} diff --git a/src/login/token/mod.rs b/src/login/token/mod.rs deleted file mode 100644 index d563a88..0000000 --- a/src/login/token/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod id; - -pub use self::id::Id; diff --git a/src/login/types.rs b/src/login/types.rs index a210977..d53d436 100644 --- a/src/login/types.rs +++ b/src/login/types.rs @@ -1,4 +1,4 @@ -use crate::login::token; +use crate::token; #[derive(Clone, Debug)] pub struct TokenRevoked { diff --git a/src/message/mod.rs b/src/message/mod.rs index d563a88..9a9bf14 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -1,3 +1,9 @@ mod id; pub use self::id::Id; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Message { + pub id: Id, + pub body: String, +} diff --git a/src/password.rs b/src/password.rs deleted file mode 100644 index da3930f..0000000 --- a/src/password.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::fmt; - -use argon2::Argon2; -use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; -use rand_core::OsRng; - -#[derive(Debug, sqlx::Type)] -#[sqlx(transparent)] -pub struct StoredHash(String); - -impl StoredHash { - pub fn verify(&self, password: &Password) -> Result { - let hash = PasswordHash::new(&self.0)?; - - match Argon2::default().verify_password(password.as_bytes(), &hash) { - // Successful authentication, not an error - Ok(()) => Ok(true), - // Unsuccessful authentication, also not an error - Err(password_hash::errors::Error::Password) => Ok(false), - // Password validation failed for some other reason, treat as an error - Err(err) => Err(err), - } - } -} - -#[derive(serde::Deserialize)] -#[serde(transparent)] -pub struct Password(String); - -impl Password { - pub fn hash(&self) -> Result { - let Self(password) = self; - let salt = SaltString::generate(&mut OsRng); - let argon2 = Argon2::default(); - let hash = argon2 - .hash_password(password.as_bytes(), &salt)? 
- .to_string(); - Ok(StoredHash(hash)) - } - - fn as_bytes(&self) -> &[u8] { - let Self(value) = self; - value.as_bytes() - } -} - -impl fmt::Debug for Password { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Password").field(&"********").finish() - } -} - -#[cfg(test)] -impl From for Password { - fn from(password: String) -> Self { - Self(password) - } -} diff --git a/src/repo/channel.rs b/src/repo/channel.rs index 9f1d930..18cd81f 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -1,10 +1,9 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use super::sequence::Sequence; use crate::{ - channel::Id, + channel::{Channel, Id}, clock::DateTime, - events::types::{self}, + event::{types, Sequence}, }; pub trait Provider { @@ -19,16 +18,6 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Channel { - pub id: Id, - pub name: String, - #[serde(skip)] - pub created_at: DateTime, - #[serde(skip)] - pub created_sequence: Sequence, -} - impl<'c> Channels<'c> { pub async fn create( &mut self, diff --git a/src/repo/login.rs b/src/repo/login.rs new file mode 100644 index 0000000..d1a02c4 --- /dev/null +++ b/src/repo/login.rs @@ -0,0 +1,50 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::login::{password::StoredHash, Id, Login}; + +pub trait Provider { + fn logins(&mut self) -> Logins; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn logins(&mut self) -> Logins { + Logins(self) + } +} + +pub struct Logins<'t>(&'t mut SqliteConnection); + +impl<'c> Logins<'c> { + pub async fn create( + &mut self, + name: &str, + password_hash: &StoredHash, + ) -> Result { + let id = Id::generate(); + + let login = sqlx::query_as!( + Login, + r#" + insert or fail + into login (id, name, password_hash) + values ($1, $2, $3) + returning + id as "id: Id", + name + "#, + id, + name, + password_hash, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(login) + } +} + +impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { + fn from(tx: &'t mut SqliteConnection) -> Self { + Self(tx) + } +} diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs deleted file mode 100644 index ab61106..0000000 --- a/src/repo/login/extract.rs +++ /dev/null @@ -1,15 +0,0 @@ -use axum::{extract::FromRequestParts, http::request::Parts}; - -use super::Login; -use crate::{app::App, login::extract::Identity}; - -#[async_trait::async_trait] -impl FromRequestParts for Login { - type Rejection = >::Rejection; - - async fn from_request_parts(parts: &mut Parts, state: &App) -> Result { - let identity = Identity::from_request_parts(parts, state).await?; - - Ok(identity.login) - } -} diff --git a/src/repo/login/mod.rs b/src/repo/login/mod.rs deleted file mode 100644 index 4ff7a96..0000000 --- a/src/repo/login/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod extract; -mod store; - -pub use self::store::{Login, Provider}; diff --git a/src/repo/login/store.rs b/src/repo/login/store.rs deleted file mode 100644 index 47d1a7c..0000000 --- a/src/repo/login/store.rs +++ /dev/null @@ -1,63 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{login::Id, password::StoredHash}; - -pub trait Provider { - fn logins(&mut self) -> Logins; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn logins(&mut self) -> Logins { - Logins(self) - } -} - -pub struct Logins<'t>(&'t mut SqliteConnection); - -// This also implements 
FromRequestParts (see `./extract.rs`). As a result, it -// can be used as an extractor for endpoints that want to require login, or for -// endpoints that need to behave differently depending on whether the client is -// or is not logged in. -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Login { - pub id: Id, - pub name: String, - // The omission of the hashed password is deliberate, to minimize the - // chance that it ends up tangled up in debug output or in some other chunk - // of logic elsewhere. -} - -impl<'c> Logins<'c> { - pub async fn create( - &mut self, - name: &str, - password_hash: &StoredHash, - ) -> Result { - let id = Id::generate(); - - let login = sqlx::query_as!( - Login, - r#" - insert or fail - into login (id, name, password_hash) - values ($1, $2, $3) - returning - id as "id: Id", - name - "#, - id, - name, - password_hash, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(login) - } -} - -impl<'t> From<&'t mut SqliteConnection> for Logins<'t> { - fn from(tx: &'t mut SqliteConnection) -> Self { - Self(tx) - } -} diff --git a/src/repo/message.rs b/src/repo/message.rs deleted file mode 100644 index acde3ea..0000000 --- a/src/repo/message.rs +++ /dev/null @@ -1,7 +0,0 @@ -use crate::message::Id; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Message { - pub id: Id, - pub body: String, -} diff --git a/src/repo/mod.rs b/src/repo/mod.rs index 8f271f4..69ad82c 100644 --- a/src/repo/mod.rs +++ b/src/repo/mod.rs @@ -1,7 +1,6 @@ pub mod channel; pub mod error; pub mod login; -pub mod message; pub mod pool; pub mod sequence; pub mod token; diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs index c47b41c..c985869 100644 --- a/src/repo/sequence.rs +++ b/src/repo/sequence.rs @@ -1,7 +1,7 @@ -use std::fmt; - use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; +use crate::event::Sequence; + pub trait Provider { fn sequence(&mut self) -> Sequences; } @@ -42,26 +42,3 @@ impl<'c> Sequences<'c> { Ok(next) } } - -#[derive( - Clone, - Copy, - Debug, - Eq, - Ord, - PartialEq, - PartialOrd, - serde::Deserialize, - serde::Serialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); - -impl fmt::Display for Sequence { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self(value) = self; - value.fmt(f) - } -} diff --git a/src/repo/token.rs b/src/repo/token.rs index 79e5c54..5f64dac 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -1,10 +1,10 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use uuid::Uuid; -use super::login::Login; use crate::{ clock::DateTime, - login::{self, extract::IdentitySecret, token::Id}, + login::{self, Login}, + token::{Id, Secret}, }; pub trait Provider { @@ -26,7 +26,7 @@ impl<'c> Tokens<'c> { &mut self, login: &Login, issued_at: &DateTime, - ) -> Result { + ) -> Result { let id = Id::generate(); let secret = Uuid::new_v4().to_string(); @@ -35,7 +35,7 @@ impl<'c> Tokens<'c> { insert into token (id, secret, login, issued_at, last_used_at) values ($1, $2, $3, $4, $4) - returning secret as "secret!: IdentitySecret" + returning secret as "secret!: Secret" "#, id, secret, @@ -103,7 +103,7 @@ impl<'c> Tokens<'c> { // timestamp will be set to `used_at`. 
pub async fn validate( &mut self, - secret: &IdentitySecret, + secret: &Secret, used_at: &DateTime, ) -> Result<(Id, Login), sqlx::Error> { // I would use `update … returning` to do this in one query, but diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs index 8744470..b678717 100644 --- a/src/test/fixtures/channel.rs +++ b/src/test/fixtures/channel.rs @@ -4,7 +4,7 @@ use faker_rand::{ }; use rand; -use crate::{app::App, clock::RequestedAt, repo::channel::Channel}; +use crate::{app::App, channel::Channel, clock::RequestedAt}; pub async fn create(app: &App, created_at: &RequestedAt) -> Channel { let name = propose(); diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs index c31fa58..d1939a5 100644 --- a/src/test/fixtures/filter.rs +++ b/src/test/fixtures/filter.rs @@ -1,6 +1,6 @@ use futures::future; -use crate::events::types; +use crate::event::types; pub fn messages() -> impl FnMut(&types::ChannelEvent) -> future::Ready { |event| future::ready(matches!(event.data, types::ChannelEventData::Message(_))) diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index 633fb8a..9e8e403 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -3,8 +3,11 @@ use uuid::Uuid; use crate::{ app::App, clock::RequestedAt, - login::extract::{Identity, IdentitySecret, IdentityToken}, - password::Password, + login::Password, + token::{ + extract::{Identity, IdentityToken}, + Secret, + }, }; pub fn not_logged_in() -> IdentityToken { @@ -38,7 +41,7 @@ pub async fn identity(app: &App, login: &(String, Password), issued_at: &Request from_token(app, &secret, issued_at).await } -pub fn secret(identity: &IdentityToken) -> IdentitySecret { +pub fn secret(identity: &IdentityToken) -> Secret { identity.secret().expect("identity contained a secret") } diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index d6a321b..00c2789 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -3,8 +3,7 @@ use uuid::Uuid; use crate::{ app::App, - password::Password, - repo::login::{self, Login}, + login::{self, Login, Password}, }; pub async fn create_with_password(app: &App) -> (String, Password) { diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs index bfca8cd..fd50887 100644 --- a/src/test/fixtures/message.rs +++ b/src/test/fixtures/message.rs @@ -1,11 +1,6 @@ use faker_rand::lorem::Paragraphs; -use crate::{ - app::App, - clock::RequestedAt, - events::types, - repo::{channel::Channel, login::Login}, -}; +use crate::{app::App, channel::Channel, clock::RequestedAt, event::types, login::Login}; pub async fn send( app: &App, diff --git a/src/token/extract/identity.rs b/src/token/extract/identity.rs new file mode 100644 index 0000000..42c7c60 --- /dev/null +++ b/src/token/extract/identity.rs @@ -0,0 +1,75 @@ +use axum::{ + extract::{FromRequestParts, State}, + http::request::Parts, + response::{IntoResponse, Response}, +}; + +use super::IdentityToken; + +use crate::{ + app::App, + clock::RequestedAt, + error::{Internal, Unauthorized}, + login::{app::ValidateError, Login}, + token, +}; + +#[derive(Clone, Debug)] +pub struct Identity { + pub token: token::Id, + pub login: Login, +} + +#[async_trait::async_trait] +impl FromRequestParts for Identity { + type Rejection = LoginError; + + async fn from_request_parts(parts: &mut Parts, state: &App) -> Result { + // After Rust 1.82 (and #[feature(min_exhaustive_patterns)] lands on + // stable), the following can be replaced: + // + // ``` + // let 
Ok(identity_token) = IdentityToken::from_request_parts( + // parts, + // state, + // ).await; + // ``` + let identity_token = IdentityToken::from_request_parts(parts, state).await?; + let RequestedAt(used_at) = RequestedAt::from_request_parts(parts, state).await?; + + let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; + + let app = State::::from_request_parts(parts, state).await?; + match app.logins().validate(&secret, &used_at).await { + Ok((token, login)) => Ok(Identity { token, login }), + Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), + Err(other) => Err(other.into()), + } + } +} + +pub enum LoginError { + Failure(E), + Unauthorized, +} + +impl IntoResponse for LoginError +where + E: IntoResponse, +{ + fn into_response(self) -> Response { + match self { + Self::Unauthorized => Unauthorized.into_response(), + Self::Failure(e) => e.into_response(), + } + } +} + +impl From for LoginError +where + E: Into, +{ + fn from(err: E) -> Self { + Self::Failure(err.into()) + } +} diff --git a/src/token/extract/identity_token.rs b/src/token/extract/identity_token.rs new file mode 100644 index 0000000..0a47a43 --- /dev/null +++ b/src/token/extract/identity_token.rs @@ -0,0 +1,94 @@ +use std::fmt; + +use axum::{ + extract::FromRequestParts, + http::request::Parts, + response::{IntoResponseParts, ResponseParts}, +}; +use axum_extra::extract::cookie::{Cookie, CookieJar}; + +use crate::token::Secret; + +// The usage pattern here - receive the extractor as an argument, return it in +// the response - is heavily modelled after CookieJar's own intended usage. +#[derive(Clone)] +pub struct IdentityToken { + cookies: CookieJar, +} + +impl fmt::Debug for IdentityToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IdentityToken") + .field("identity", &self.secret()) + .finish() + } +} + +impl IdentityToken { + // Creates a new, unpopulated identity token store. + #[cfg(test)] + pub fn new() -> Self { + Self { + cookies: CookieJar::new(), + } + } + + // Get the identity secret sent in the request, if any. If the identity + // was not sent, or if it has previously been [clear]ed, then this will + // return [None]. If the identity has previously been [set], then this + // will return that secret, regardless of what the request originally + // included. + pub fn secret(&self) -> Option { + self.cookies + .get(IDENTITY_COOKIE) + .map(Cookie::value) + .map(Secret::from) + } + + // Positively set the identity secret, and ensure that it will be sent + // back to the client when this extractor is included in a response. + pub fn set(self, secret: impl Into) -> Self { + let secret = secret.into().reveal(); + let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret)) + .http_only(true) + .path("/api/") + .permanent() + .build(); + + Self { + cookies: self.cookies.add(identity_cookie), + } + } + + // Remove the identity secret and ensure that it will be cleared when this + // extractor is included in a response. 
+ pub fn clear(self) -> Self { + Self { + cookies: self.cookies.remove(IDENTITY_COOKIE), + } + } +} + +const IDENTITY_COOKIE: &str = "identity"; + +#[async_trait::async_trait] +impl FromRequestParts for IdentityToken +where + S: Send + Sync, +{ + type Rejection = >::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let cookies = CookieJar::from_request_parts(parts, state).await?; + Ok(Self { cookies }) + } +} + +impl IntoResponseParts for IdentityToken { + type Error = ::Error; + + fn into_response_parts(self, res: ResponseParts) -> Result { + let Self { cookies } = self; + cookies.into_response_parts(res) + } +} diff --git a/src/token/extract/mod.rs b/src/token/extract/mod.rs new file mode 100644 index 0000000..b4800ae --- /dev/null +++ b/src/token/extract/mod.rs @@ -0,0 +1,4 @@ +mod identity; +mod identity_token; + +pub use self::{identity::Identity, identity_token::IdentityToken}; diff --git a/src/token/id.rs b/src/token/id.rs new file mode 100644 index 0000000..9ef063c --- /dev/null +++ b/src/token/id.rs @@ -0,0 +1,27 @@ +use std::fmt; + +use crate::id::Id as BaseId; + +// Stable identifier for a token. Prefixed with `T`. +#[derive(Clone, Debug, Eq, Hash, PartialEq, sqlx::Type, serde::Deserialize, serde::Serialize)] +#[sqlx(transparent)] +#[serde(transparent)] +pub struct Id(BaseId); + +impl From for Id { + fn from(id: BaseId) -> Self { + Self(id) + } +} + +impl Id { + pub fn generate() -> Self { + BaseId::generate("T") + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} diff --git a/src/token/mod.rs b/src/token/mod.rs new file mode 100644 index 0000000..c98b8c2 --- /dev/null +++ b/src/token/mod.rs @@ -0,0 +1,5 @@ +pub mod extract; +mod id; +mod secret; + +pub use self::{id::Id, secret::Secret}; diff --git a/src/token/secret.rs b/src/token/secret.rs new file mode 100644 index 0000000..28c93bb --- /dev/null +++ b/src/token/secret.rs @@ -0,0 +1,27 @@ +use std::fmt; + +#[derive(sqlx::Type)] +#[sqlx(transparent)] +pub struct Secret(String); + +impl Secret { + pub fn reveal(self) -> String { + let Self(secret) = self; + secret + } +} + +impl fmt::Debug for Secret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IdentityToken").field(&"********").finish() + } +} + +impl From for Secret +where + S: Into, +{ + fn from(value: S) -> Self { + Self(value.into()) + } +} -- cgit v1.2.3 From 6f07e6869bbf62903ac83c9bc061e7bde997e6a8 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Wed, 2 Oct 2024 01:10:09 -0400 Subject: Retire top-level `repo`. This helped me discover an organizational scheme I like more. 
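The scheme in question, visible throughout the diff below, gives each domain module its own `repo` submodule that exposes a `Provider` trait on the database transaction, rather than collecting every repository under one top-level `repo` tree. A minimal sketch of the shape, using a placeholder `Widgets` domain that is not part of this codebase:

```
use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};

// Each domain module declares its own Provider so call sites can write
// `tx.widgets()` without importing repository types from a central,
// crate-wide `repo` module. (`Widgets` is a stand-in name, not from this
// patch.)
pub trait Provider {
    fn widgets(&mut self) -> Widgets;
}

impl<'c> Provider for Transaction<'c, Sqlite> {
    fn widgets(&mut self) -> Widgets {
        Widgets(self)
    }
}

// The repository borrows the transaction's connection for its lifetime.
pub struct Widgets<'t>(&'t mut SqliteConnection);
```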
--- src/channel/app.rs | 6 +- src/channel/mod.rs | 1 + src/channel/repo.rs | 165 +++++++++++++++++++++++++++++++++++++++++++++ src/cli.rs | 4 +- src/db.rs | 42 ++++++++++++ src/event/app.rs | 6 +- src/event/repo/mod.rs | 3 + src/event/repo/sequence.rs | 44 ++++++++++++ src/lib.rs | 2 +- src/login/app.rs | 2 +- src/repo/channel.rs | 165 --------------------------------------------- src/repo/error.rs | 23 ------- src/repo/mod.rs | 4 -- src/repo/pool.rs | 18 ----- src/repo/sequence.rs | 44 ------------ src/test/fixtures/mod.rs | 4 +- src/token/app.rs | 2 +- 17 files changed, 267 insertions(+), 268 deletions(-) create mode 100644 src/channel/repo.rs create mode 100644 src/db.rs create mode 100644 src/event/repo/sequence.rs delete mode 100644 src/repo/channel.rs delete mode 100644 src/repo/error.rs delete mode 100644 src/repo/mod.rs delete mode 100644 src/repo/pool.rs delete mode 100644 src/repo/sequence.rs (limited to 'src/channel/mod.rs') diff --git a/src/channel/app.rs b/src/channel/app.rs index 1422651..ef0a63f 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -2,11 +2,9 @@ use chrono::TimeDelta; use sqlx::sqlite::SqlitePool; use crate::{ - channel::Channel, + channel::{repo::Provider as _, Channel}, clock::DateTime, - event::Sequence, - event::{broadcaster::Broadcaster, types::ChannelEvent}, - repo::{channel::Provider as _, sequence::Provider as _}, + event::{broadcaster::Broadcaster, repo::Provider as _, types::ChannelEvent, Sequence}, }; pub struct Channels<'a> { diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 02d0ed4..2672084 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -2,6 +2,7 @@ use crate::{clock::DateTime, event::Sequence}; pub mod app; mod id; +pub mod repo; mod routes; pub use self::{id::Id, routes::router}; diff --git a/src/channel/repo.rs b/src/channel/repo.rs new file mode 100644 index 0000000..18cd81f --- /dev/null +++ b/src/channel/repo.rs @@ -0,0 +1,165 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + channel::{Channel, Id}, + clock::DateTime, + event::{types, Sequence}, +}; + +pub trait Provider { + fn channels(&mut self) -> Channels; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn channels(&mut self) -> Channels { + Channels(self) + } +} + +pub struct Channels<'t>(&'t mut SqliteConnection); + +impl<'c> Channels<'c> { + pub async fn create( + &mut self, + name: &str, + created_at: &DateTime, + created_sequence: Sequence, + ) -> Result { + let id = Id::generate(); + let channel = sqlx::query_as!( + Channel, + r#" + insert + into channel (id, name, created_at, created_sequence) + values ($1, $2, $3, $4) + returning + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + "#, + id, + name, + created_at, + created_sequence, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn by_id(&mut self, channel: &Id) -> Result { + let channel = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where id = $1 + "#, + channel, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(channel) + } + + pub async fn all( + &mut self, + resume_point: Option, + ) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence <= 
$1, true) + order by channel.name + "#, + resume_point, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn replay( + &mut self, + resume_at: Option, + ) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" + from channel + where coalesce(created_sequence > $1, true) + "#, + resume_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } + + pub async fn delete( + &mut self, + channel: &Channel, + deleted_at: &DateTime, + deleted_sequence: Sequence, + ) -> Result { + let channel = channel.id.clone(); + sqlx::query_scalar!( + r#" + delete from channel + where id = $1 + returning 1 as "row: i64" + "#, + channel, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence: deleted_sequence, + at: *deleted_at, + data: types::DeletedEvent { channel }.into(), + }) + } + + pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + channel.id as "id: Id", + channel.name, + channel.created_at as "created_at: DateTime", + channel.created_sequence as "created_sequence: Sequence" + from channel + left join message + where created_at < $1 + and message.id is null + "#, + expired_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } +} diff --git a/src/cli.rs b/src/cli.rs index ee95ea6..893fae2 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, event, expire, login, repo::pool}; +use crate::{app::App, channel, clock, db, event, expire, login}; /// Command-line entry point for running the `hi` server. /// @@ -100,7 +100,7 @@ impl Args { } async fn pool(&self) -> sqlx::Result { - pool::prepare(&self.database_url).await + db::prepare(&self.database_url).await } } diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..93a1169 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,42 @@ +use std::str::FromStr; + +use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, SqlitePoolOptions}; + +pub async fn prepare(url: &str) -> sqlx::Result { + let pool = create(url).await?; + sqlx::migrate!().run(&pool).await?; + Ok(pool) +} + +async fn create(database_url: &str) -> sqlx::Result { + let options = SqliteConnectOptions::from_str(database_url)? 
+ .create_if_missing(true) + .optimize_on_close(true, /* analysis_limit */ None); + + let pool = SqlitePoolOptions::new().connect_with(options).await?; + Ok(pool) +} + +pub trait NotFound { + type Ok; + fn not_found(self, map: F) -> Result + where + E: From, + F: FnOnce() -> E; +} + +impl NotFound for Result { + type Ok = T; + + fn not_found(self, map: F) -> Result + where + E: From, + F: FnOnce() -> E, + { + match self { + Err(sqlx::Error::RowNotFound) => Err(map()), + Err(other) => Err(other.into()), + Ok(value) => Ok(value), + } + } +} diff --git a/src/event/app.rs b/src/event/app.rs index b5f2ecc..3d35f1a 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -12,11 +12,11 @@ use super::{ types::{self, ChannelEvent}, }; use crate::{ - channel, + channel::{self, repo::Provider as _}, clock::DateTime, - event::Sequence, + db::NotFound as _, + event::{repo::Provider as _, Sequence}, login::Login, - repo::{channel::Provider as _, error::NotFound as _, sequence::Provider as _}, }; pub struct Events<'a> { diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs index e216a50..cee840c 100644 --- a/src/event/repo/mod.rs +++ b/src/event/repo/mod.rs @@ -1 +1,4 @@ pub mod message; +mod sequence; + +pub use self::sequence::Provider; diff --git a/src/event/repo/sequence.rs b/src/event/repo/sequence.rs new file mode 100644 index 0000000..c985869 --- /dev/null +++ b/src/event/repo/sequence.rs @@ -0,0 +1,44 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::event::Sequence; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } + + pub async fn current(&mut self) -> Result { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/lib.rs b/src/lib.rs index bbcb314..8ec13da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,13 +7,13 @@ mod broadcast; mod channel; pub mod cli; mod clock; +mod db; mod error; mod event; mod expire; mod id; mod login; mod message; -mod repo; #[cfg(test)] mod test; mod token; diff --git a/src/login/app.rs b/src/login/app.rs index 69c1055..15adb31 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,6 +1,6 @@ use sqlx::sqlite::SqlitePool; -use crate::{event::Sequence, repo::sequence::Provider as _}; +use crate::event::{repo::Provider as _, Sequence}; #[cfg(test)] use super::{repo::Provider as _, Login, Password}; diff --git a/src/repo/channel.rs b/src/repo/channel.rs deleted file mode 100644 index 18cd81f..0000000 --- a/src/repo/channel.rs +++ /dev/null @@ -1,165 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - channel::{Channel, Id}, - clock::DateTime, - event::{types, Sequence}, -}; - -pub trait Provider { - fn channels(&mut self) -> Channels; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn channels(&mut self) -> Channels { - Channels(self) - } -} - -pub struct Channels<'t>(&'t mut SqliteConnection); - -impl<'c> Channels<'c> { - pub async fn create( - &mut self, - name: &str, - created_at: &DateTime, - 
created_sequence: Sequence, - ) -> Result { - let id = Id::generate(); - let channel = sqlx::query_as!( - Channel, - r#" - insert - into channel (id, name, created_at, created_sequence) - values ($1, $2, $3, $4) - returning - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - "#, - id, - name, - created_at, - created_sequence, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(channel) - } - - pub async fn by_id(&mut self, channel: &Id) -> Result { - let channel = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - from channel - where id = $1 - "#, - channel, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(channel) - } - - pub async fn all( - &mut self, - resume_point: Option, - ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - from channel - where coalesce(created_sequence <= $1, true) - order by channel.name - "#, - resume_point, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } - - pub async fn replay( - &mut self, - resume_at: Option, - ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - id as "id: Id", - name, - created_at as "created_at: DateTime", - created_sequence as "created_sequence: Sequence" - from channel - where coalesce(created_sequence > $1, true) - "#, - resume_at, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } - - pub async fn delete( - &mut self, - channel: &Channel, - deleted_at: &DateTime, - deleted_sequence: Sequence, - ) -> Result { - let channel = channel.id.clone(); - sqlx::query_scalar!( - r#" - delete from channel - where id = $1 - returning 1 as "row: i64" - "#, - channel, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, - data: types::DeletedEvent { channel }.into(), - }) - } - - pub async fn expired(&mut self, expired_at: &DateTime) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, - r#" - select - channel.id as "id: Id", - channel.name, - channel.created_at as "created_at: DateTime", - channel.created_sequence as "created_sequence: Sequence" - from channel - left join message - where created_at < $1 - and message.id is null - "#, - expired_at, - ) - .fetch_all(&mut *self.0) - .await?; - - Ok(channels) - } -} diff --git a/src/repo/error.rs b/src/repo/error.rs deleted file mode 100644 index a5961e2..0000000 --- a/src/repo/error.rs +++ /dev/null @@ -1,23 +0,0 @@ -pub trait NotFound { - type Ok; - fn not_found(self, map: F) -> Result - where - E: From, - F: FnOnce() -> E; -} - -impl NotFound for Result { - type Ok = T; - - fn not_found(self, map: F) -> Result - where - E: From, - F: FnOnce() -> E, - { - match self { - Err(sqlx::Error::RowNotFound) => Err(map()), - Err(other) => Err(other.into()), - Ok(value) => Ok(value), - } - } -} diff --git a/src/repo/mod.rs b/src/repo/mod.rs deleted file mode 100644 index 7abd46b..0000000 --- a/src/repo/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod channel; -pub mod error; -pub mod pool; -pub mod sequence; diff --git a/src/repo/pool.rs b/src/repo/pool.rs deleted file mode 100644 index b4aa6fc..0000000 --- a/src/repo/pool.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::str::FromStr; - -use sqlx::sqlite::{SqliteConnectOptions, SqlitePool, 
SqlitePoolOptions}; - -pub async fn prepare(url: &str) -> sqlx::Result { - let pool = create(url).await?; - sqlx::migrate!().run(&pool).await?; - Ok(pool) -} - -async fn create(database_url: &str) -> sqlx::Result { - let options = SqliteConnectOptions::from_str(database_url)? - .create_if_missing(true) - .optimize_on_close(true, /* analysis_limit */ None); - - let pool = SqlitePoolOptions::new().connect_with(options).await?; - Ok(pool) -} diff --git a/src/repo/sequence.rs b/src/repo/sequence.rs deleted file mode 100644 index c985869..0000000 --- a/src/repo/sequence.rs +++ /dev/null @@ -1,44 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::event::Sequence; - -pub trait Provider { - fn sequence(&mut self) -> Sequences; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn sequence(&mut self) -> Sequences { - Sequences(self) - } -} - -pub struct Sequences<'t>(&'t mut SqliteConnection); - -impl<'c> Sequences<'c> { - pub async fn next(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - update event_sequence - set last_value = last_value + 1 - returning last_value as "next_value: Sequence" - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } - - pub async fn current(&mut self) -> Result { - let next = sqlx::query_scalar!( - r#" - select last_value as "last_value: Sequence" - from event_sequence - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } -} diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index d1dd0c3..76467ab 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -1,6 +1,6 @@ use chrono::{TimeDelta, Utc}; -use crate::{app::App, clock::RequestedAt, repo::pool}; +use crate::{app::App, clock::RequestedAt, db}; pub mod channel; pub mod filter; @@ -10,7 +10,7 @@ pub mod login; pub mod message; pub async fn scratch_app() -> App { - let pool = pool::prepare("sqlite::memory:") + let pool = db::prepare("sqlite::memory:") .await .expect("setting up in-memory sqlite database"); App::from(pool) diff --git a/src/token/app.rs b/src/token/app.rs index 1477a9f..030ec69 100644 --- a/src/token/app.rs +++ b/src/token/app.rs @@ -11,8 +11,8 @@ use super::{ }; use crate::{ clock::DateTime, + db::NotFound as _, login::{repo::Provider as _, Login, Password}, - repo::error::NotFound as _, }; pub struct Tokens<'a> { -- cgit v1.2.3 From 469613872f6fb19f4579b387e19b2bc38fa52f51 Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Wed, 2 Oct 2024 01:31:43 -0400 Subject: Package up common event fields as Instant --- src/channel/app.rs | 11 +++--- src/channel/mod.rs | 6 ++-- src/channel/repo.rs | 74 ++++++++++++++++++++++++++------------ src/channel/routes/test/on_send.rs | 2 +- src/event/app.rs | 18 +++++----- src/event/mod.rs | 9 +++++ src/event/repo/message.rs | 42 +++++++++++++--------- src/event/repo/sequence.rs | 12 +++++-- src/event/routes.rs | 2 +- src/event/routes/test.rs | 8 ++--- src/event/types.rs | 22 +++--------- 11 files changed, 121 insertions(+), 85 deletions(-) (limited to 'src/channel/mod.rs') diff --git a/src/channel/app.rs b/src/channel/app.rs index ef0a63f..b7e3a10 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -19,10 +19,10 @@ impl<'a> Channels<'a> { pub async fn create(&self, name: &str, created_at: &DateTime) -> Result { let mut tx = self.db.begin().await?; - let created_sequence = tx.sequence().next().await?; + let created = tx.sequence().next(created_at).await?; let channel = tx .channels() - .create(name, created_at, created_sequence) + .create(name, &created) .await 
.map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; @@ -50,11 +50,8 @@ impl<'a> Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { - let deleted_sequence = tx.sequence().next().await?; - let event = tx - .channels() - .delete(&channel, relative_to, deleted_sequence) - .await?; + let deleted = tx.sequence().next(relative_to).await?; + let event = tx.channels().delete(&channel, &deleted).await?; events.push(event); } diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 2672084..4baa7e3 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,4 +1,4 @@ -use crate::{clock::DateTime, event::Sequence}; +use crate::event::Instant; pub mod app; mod id; @@ -12,7 +12,5 @@ pub struct Channel { pub id: Id, pub name: String, #[serde(skip)] - pub created_at: DateTime, - #[serde(skip)] - pub created_sequence: Sequence, + pub created: Instant, } diff --git a/src/channel/repo.rs b/src/channel/repo.rs index 18cd81f..c000b56 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -3,7 +3,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ channel::{Channel, Id}, clock::DateTime, - event::{types, Sequence}, + event::{types, Instant, Sequence}, }; pub trait Provider { @@ -19,15 +19,9 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); impl<'c> Channels<'c> { - pub async fn create( - &mut self, - name: &str, - created_at: &DateTime, - created_sequence: Sequence, - ) -> Result { + pub async fn create(&mut self, name: &str, created: &Instant) -> Result { let id = Id::generate(); - let channel = sqlx::query_as!( - Channel, + let channel = sqlx::query!( r#" insert into channel (id, name, created_at, created_sequence) @@ -40,9 +34,17 @@ impl<'c> Channels<'c> { "#, id, name, - created_at, - created_sequence, + created.at, + created.sequence, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_one(&mut *self.0) .await?; @@ -50,8 +52,7 @@ impl<'c> Channels<'c> { } pub async fn by_id(&mut self, channel: &Id) -> Result { - let channel = sqlx::query_as!( - Channel, + let channel = sqlx::query!( r#" select id as "id: Id", @@ -63,6 +64,14 @@ impl<'c> Channels<'c> { "#, channel, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_one(&mut *self.0) .await?; @@ -73,8 +82,7 @@ impl<'c> Channels<'c> { &mut self, resume_point: Option, ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, + let channels = sqlx::query!( r#" select id as "id: Id", @@ -87,6 +95,14 @@ impl<'c> Channels<'c> { "#, resume_point, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_all(&mut *self.0) .await?; @@ -97,8 +113,7 @@ impl<'c> Channels<'c> { &mut self, resume_at: Option, ) -> Result, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, + let channels = sqlx::query!( r#" select id as "id: Id", @@ -110,6 +125,14 @@ impl<'c> Channels<'c> { "#, resume_at, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_all(&mut *self.0) .await?; @@ -119,8 +142,7 @@ impl<'c> Channels<'c> { pub async fn delete( &mut self, channel: &Channel, - deleted_at: &DateTime, - deleted_sequence: 
Sequence, + deleted: &Instant, ) -> Result<types::ChannelEvent, sqlx::Error> { let channel = channel.id.clone(); sqlx::query_scalar!( @@ -135,15 +157,13 @@ impl<'c> Channels<'c> { .await?; Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, + instant: *deleted, data: types::DeletedEvent { channel }.into(), }) } pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Channel>, sqlx::Error> { - let channels = sqlx::query_as!( - Channel, + let channels = sqlx::query!( r#" select channel.id as "id: Id", @@ -157,6 +177,14 @@ impl<'c> Channels<'c> { "#, expired_at, ) + .map(|row| Channel { + id: row.id, + name: row.name, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + }) .fetch_all(&mut *self.0) .await?; diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 6f844cd..33ec3b7 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -52,7 +52,7 @@ async fn messages_in_order() { let events = events.collect::<Vec<_>>().immediately().await; for ((sent_at, message), event) in requests.into_iter().zip(events) { - assert_eq!(*sent_at, event.at); + assert_eq!(*sent_at, event.instant.at); assert!(matches!( event.data, types::ChannelEventData::Message(event_message) diff --git a/src/event/app.rs b/src/event/app.rs index 3d35f1a..5e9e79a 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -42,10 +42,10 @@ impl<'a> Events<'a> { .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let sent_sequence = tx.sequence().next().await?; + let sent = tx.sequence().next(sent_at).await?; let event = tx .message_events() - .create(login, &channel, sent_at, sent_sequence, body) + .create(login, &channel, &sent, body) .await?; tx.commit().await?; @@ -62,10 +62,10 @@ impl<'a> Events<'a> { let mut events = Vec::with_capacity(expired.len()); for (channel, message) in expired { - let deleted_sequence = tx.sequence().next().await?; + let deleted = tx.sequence().next(relative_to).await?; let event = tx .message_events() - .delete(&channel, &message, relative_to, deleted_sequence) + .delete(&channel, &message, &deleted) .await?; events.push(event); } @@ -93,7 +93,9 @@ impl<'a> Events<'a> { let channel_events = channels .into_iter() .map(ChannelEvent::created) - .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); + .filter(move |event| { + resume_at.map_or(true, |resume_at| Sequence::from(event) > resume_at) + }); let message_events = tx.message_events().replay(resume_at).await?; @@ -101,8 +103,8 @@ impl<'a> Events<'a> { .into_iter() .chain(message_events.into_iter()) .collect::<Vec<_>>(); - replay_events.sort_by_key(|event| event.sequence); - let resume_live_at = replay_events.last().map(|event| event.sequence); + replay_events.sort_by_key(|event| Sequence::from(event)); + let resume_live_at = replay_events.last().map(Sequence::from); let replay = stream::iter(replay_events); @@ -124,7 +126,7 @@ impl<'a> Events<'a> { fn resume( resume_at: Option<Sequence>, ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { - move |event| future::ready(resume_at < Some(event.sequence)) + move |event| future::ready(resume_at < Some(Sequence::from(event))) } } diff --git a/src/event/mod.rs b/src/event/mod.rs index 7ad3f9c..c982d3a 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -6,4 +6,13 @@ mod routes; mod sequence; pub mod types; +use crate::clock::DateTime; + pub use self::{routes::router, sequence::Sequence}; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)] +pub
struct Instant { + pub at: DateTime, + #[serde(skip)] + pub sequence: Sequence, +} diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs index f051fec..f29c8a4 100644 --- a/src/event/repo/message.rs +++ b/src/event/repo/message.rs @@ -3,7 +3,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ channel::{self, Channel}, clock::DateTime, - event::{types, Sequence}, + event::{types, Instant, Sequence}, login::{self, Login}, message::{self, Message}, }; @@ -25,8 +25,7 @@ impl<'c> Events<'c> { &mut self, sender: &Login, channel: &Channel, - sent_at: &DateTime, - sent_sequence: Sequence, + sent: &Instant, body: &str, ) -> Result<types::ChannelEvent, sqlx::Error> { let id = message::Id::generate(); @@ -46,13 +45,15 @@ impl<'c> Events<'c> { id, channel.id, sender.id, - sent_at, - sent_sequence, + sent.at, + sent.sequence, body, ) .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, + instant: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, data: types::MessageEvent { channel: channel.clone(), sender: sender.clone(), @@ -73,8 +74,7 @@ impl<'c> Events<'c> { &mut self, channel: &Channel, message: &message::Id, - deleted_at: &DateTime, - deleted_sequence: Sequence, + deleted: &Instant, ) -> Result<types::ChannelEvent, sqlx::Error> { sqlx::query_scalar!( r#" @@ -88,8 +88,10 @@ impl<'c> Events<'c> { .await?; Ok(types::ChannelEvent { - sequence: deleted_sequence, - at: *deleted_at, + instant: Instant { + at: deleted.at, + sequence: deleted.sequence, + }, data: types::MessageDeletedEvent { channel: channel.clone(), message: message.clone(), @@ -122,8 +124,10 @@ impl<'c> Events<'c> { Channel { id: row.channel_id, name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, + created: Instant { + at: row.channel_created_at, + sequence: row.channel_created_sequence, + }, }, row.message, ) @@ -160,14 +164,18 @@ impl<'c> Events<'c> { resume_at, ) .map(|row| types::ChannelEvent { - sequence: row.sent_sequence, - at: row.sent_at, + instant: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, data: types::MessageEvent { channel: Channel { id: row.channel_id, name: row.channel_name, - created_at: row.channel_created_at, - created_sequence: row.channel_created_sequence, + created: Instant { + at: row.channel_created_at, + sequence: row.channel_created_sequence, + }, }, sender: Login { id: row.sender_id, diff --git a/src/event/repo/sequence.rs b/src/event/repo/sequence.rs index c985869..40d6a53 100644 --- a/src/event/repo/sequence.rs +++ b/src/event/repo/sequence.rs @@ -1,6 +1,9 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::event::Sequence; +use crate::{ + clock::DateTime, + event::{Instant, Sequence}, +}; pub trait Provider { fn sequence(&mut self) -> Sequences; } @@ -15,7 +18,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Sequences<'t>(&'t mut SqliteConnection); impl<'c> Sequences<'c> { - pub async fn next(&mut self) -> Result<Sequence, sqlx::Error> { + pub async fn next(&mut self, at: &DateTime) -> Result<Instant, sqlx::Error> { let next = sqlx::query_scalar!( r#" update event_sequence @@ -26,7 +29,10 @@ impl<'c> Sequences<'c> { .fetch_one(&mut *self.0) .await?; - Ok(next) + Ok(Instant { + at: *at, + sequence: next, + }) } pub async fn current(&mut self) -> Result<Sequence, sqlx::Error> { diff --git a/src/event/routes.rs b/src/event/routes.rs index 50ac435..c87bfb2 100644 --- a/src/event/routes.rs +++ b/src/event/routes.rs @@ -66,7 +66,7 @@ impl TryFrom<types::ChannelEvent> for sse::Event { type Error = serde_json::Error; fn try_from(event: types::ChannelEvent) -> Result<Self, Self::Error> { - let
id = serde_json::to_string(&event.sequence)?; + let id = serde_json::to_string(&Sequence::from(&event))?; let data = serde_json::to_string_pretty(&event)?; let event = Self::default().id(id).data(data); diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs index d1ac3b4..68b55cc 100644 --- a/src/event/routes/test.rs +++ b/src/event/routes/test.rs @@ -6,7 +6,7 @@ use futures::{ }; use crate::{ - event::routes, + event::{routes, Sequence}, test::fixtures::{self, future::Immediately as _}, }; @@ -192,7 +192,7 @@ async fn resumes_from() { assert_eq!(initial_message, event); - event.sequence + Sequence::from(&event) }; // Resume after disconnect @@ -276,7 +276,7 @@ async fn serial_resume() { let event = events.last().expect("this vec is non-empty"); - event.sequence + Sequence::from(event) }; // Resume after disconnect @@ -312,7 +312,7 @@ async fn serial_resume() { let event = events.last().expect("this vec is non-empty"); - event.sequence + Sequence::from(event) }; // Resume after disconnect a second time diff --git a/src/event/types.rs b/src/event/types.rs index cd7dea6..2324dc1 100644 --- a/src/event/types.rs +++ b/src/event/types.rs @@ -1,16 +1,14 @@ use crate::{ channel::{self, Channel}, - clock::DateTime, - event::Sequence, + event::{Instant, Sequence}, login::Login, message::{self, Message}, }; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct ChannelEvent { - #[serde(skip)] - pub sequence: Sequence, - pub at: DateTime, + #[serde(flatten)] + pub instant: Instant, #[serde(flatten)] pub data: ChannelEventData, } @@ -18,25 +16,15 @@ pub struct ChannelEvent { impl ChannelEvent { pub fn created(channel: Channel) -> Self { Self { - at: channel.created_at, - sequence: channel.created_sequence, + instant: channel.created, data: CreatedEvent { channel }.into(), } } - - pub fn channel_id(&self) -> &channel::Id { - match &self.data { - ChannelEventData::Created(event) => &event.channel.id, - ChannelEventData::Message(event) => &event.channel.id, - ChannelEventData::MessageDeleted(event) => &event.channel.id, - ChannelEventData::Deleted(event) => &event.channel, - } - } } impl<'c> From<&'c ChannelEvent> for Sequence { fn from(event: &'c ChannelEvent) -> Self { - event.sequence + event.instant.sequence } } -- cgit v1.2.3 From ec804134c33aedb001c426c5f42f43f53c47848f Mon Sep 17 00:00:00 2001 From: Owen Jacobson Date: Wed, 2 Oct 2024 12:25:36 -0400 Subject: Represent channels and messages using a split "History" and "Snapshot" model. This separates the code that figures out what happened to an entity from the code that represents it to a user, and makes it easier to compute a snapshot at a point in time (for things like bootstrap). It also makes the internal logic a bit easier to follow, since it's easier to tell whether you're working with a point in time or with the whole recorded history. This one is hefty.
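To make the split concrete, here is a minimal sketch of the pattern this commit applies to channels (and, analogously, messages). It is an illustration only: the string ids, integer sequence numbers, and the free function `snapshot` are simplified stand-ins, not this crate's actual `channel::History`, `channel::snapshot::Channel`, or `event::Instant` definitions — the real snapshot is folded through a `FromIterator` impl, as the diff below shows.

// History/Snapshot in miniature. All names are simplified stand-ins.

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Instant {
    at: i64,       // stand-in for clock::DateTime
    sequence: u64, // stand-in for event::Sequence
}

#[derive(Clone, Debug, PartialEq, Eq)]
struct Channel {
    id: String,
    name: String,
}

#[derive(Debug)]
enum Kind {
    Created(Channel),
    Deleted(String), // the deleted channel's id
}

#[derive(Debug)]
struct Event {
    instant: Instant,
    kind: Kind,
}

// The "what happened" side: everything recorded about one channel.
struct History {
    channel: Channel,
    created: Instant,
    deleted: Option<Instant>,
}

impl History {
    // Replay the recorded history as an ordered stream of events.
    fn events(&self) -> impl Iterator<Item = Event> {
        let created = Event {
            instant: self.created,
            kind: Kind::Created(self.channel.clone()),
        };
        let deleted = self.deleted.map(|instant| Event {
            instant,
            kind: Kind::Deleted(self.channel.id.clone()),
        });
        [created].into_iter().chain(deleted)
    }
}

// The "point in time" side: fold events into the state they imply. A
// deleted channel folds to None, so filtering the event stream at a
// resume point yields the snapshot as of that point.
fn snapshot(events: impl Iterator<Item = Event>) -> Option<Channel> {
    events.fold(None, |state, event| match (state, event.kind) {
        (None, Kind::Created(channel)) => Some(channel),
        (Some(channel), Kind::Deleted(id)) if channel.id == id => None,
        (state, kind) => panic!("invalid event {kind:?} for state {state:?}"),
    })
}

fn main() {
    let history = History {
        channel: Channel { id: "c1".into(), name: "general".into() },
        created: Instant { at: 0, sequence: 1 },
        deleted: Some(Instant { at: 10, sequence: 7 }),
    };

    // As of sequence 5 the channel still exists...
    let as_of_5 = snapshot(history.events().filter(|e| e.instant.sequence <= 5));
    assert_eq!(as_of_5, Some(history.channel.clone()));

    // ...but folding the whole recorded history shows it was deleted.
    assert_eq!(snapshot(history.events()), None);

    // Each instant carries both wall-clock time and a global sequence.
    let deleted = history.deleted.expect("set above");
    println!("deleted at t={} (sequence {})", deleted.at, deleted.sequence);
}

Resuming a subscription mid-stream then amounts to choosing how much of each history to replay or fold, which is what the `Sequence::up_to`, `Sequence::start_from`, and `Sequence::after` filters in the diff below do.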
--- ...86018161718e2a6788413bffffb252de3e1959f341.json | 38 ++++ ...4abfbb9e06a5889460743202ca7956acabf006843e.json | 20 ++ ...2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json | 26 +++ ...05a144db4901c302bbcd3e76da6c61742ac44345c9.json | 44 ----- ...0b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json | 32 +++ ...6854398e2151ba2dba10c03a9d2d93184141f1425c.json | 44 ----- ...4eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json | 62 ++++++ ...b8df33083c2da765dfda3023c78c25c06735670457.json | 38 ---- ...72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json | 74 ------- ...df5de8c7d9ad11bb09a0d0d243181d97c79d771071.json | 20 ++ ...6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json | 20 -- ...99e837106c799e84015425286b79f42e4001d8a4c7.json | 62 ++++++ ...76cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json | 20 -- Cargo.lock | 10 + Cargo.toml | 1 + src/app.rs | 5 + src/broadcast.rs | 4 +- src/channel/app.rs | 35 +++- src/channel/event.rs | 48 +++++ src/channel/history.rs | 42 ++++ src/channel/mod.rs | 15 +- src/channel/repo.rs | 94 +++++---- src/channel/routes.rs | 11 +- src/channel/routes/test/on_create.rs | 6 +- src/channel/routes/test/on_send.rs | 13 +- src/channel/snapshot.rs | 38 ++++ src/event/app.rs | 122 +++--------- src/event/broadcaster.rs | 4 +- src/event/mod.rs | 73 ++++++- src/event/repo.rs | 50 +++++ src/event/repo/message.rs | 196 ------------------- src/event/repo/mod.rs | 4 - src/event/repo/sequence.rs | 50 ----- src/event/routes.rs | 14 +- src/event/routes/test.rs | 98 ++++++---- src/event/sequence.rs | 59 ++++++ src/event/types.rs | 85 -------- src/expire.rs | 2 +- src/message/app.rs | 88 +++++++++ src/message/event.rs | 50 +++++ src/message/history.rs | 43 +++++ src/message/mod.rs | 13 +- src/message/repo.rs | 214 +++++++++++++++++++++ src/message/snapshot.rs | 74 +++++++ src/test/fixtures/event.rs | 11 ++ src/test/fixtures/filter.rs | 10 +- src/test/fixtures/message.rs | 13 +- src/test/fixtures/mod.rs | 1 + src/token/app.rs | 4 +- 49 files changed, 1277 insertions(+), 823 deletions(-) create mode 100644 .sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json create mode 100644 .sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json create mode 100644 .sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json delete mode 100644 .sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json create mode 100644 .sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json delete mode 100644 .sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json create mode 100644 .sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json delete mode 100644 .sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json delete mode 100644 .sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json create mode 100644 .sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json delete mode 100644 .sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json create mode 100644 .sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json delete mode 100644 .sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json create mode 100644 src/channel/event.rs create mode 100644 src/channel/history.rs create mode 100644 src/channel/snapshot.rs create mode 100644 src/event/repo.rs delete mode 100644 src/event/repo/message.rs delete mode 100644 
src/event/repo/mod.rs delete mode 100644 src/event/repo/sequence.rs delete mode 100644 src/event/types.rs create mode 100644 src/message/app.rs create mode 100644 src/message/event.rs create mode 100644 src/message/history.rs create mode 100644 src/message/repo.rs create mode 100644 src/message/snapshot.rs create mode 100644 src/test/fixtures/event.rs (limited to 'src/channel/mod.rs') diff --git a/.sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json b/.sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json new file mode 100644 index 0000000..cc716ed --- /dev/null +++ b/.sqlx/query-1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341.json @@ -0,0 +1,38 @@ +{ + "db_name": "SQLite", + "query": "\n delete from channel\n where id = $1\n returning\n id as \"id: Id\",\n name,\n created_at as \"created_at: DateTime\",\n created_sequence as \"created_sequence: Sequence\"\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "created_at: DateTime", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "created_sequence: Sequence", + "ordinal": 3, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "1654b05159c27f74cb333586018161718e2a6788413bffffb252de3e1959f341" +} diff --git a/.sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json b/.sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json new file mode 100644 index 0000000..1480953 --- /dev/null +++ b/.sqlx/query-33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n delete from message\n where\n id = $1\n returning 1 as \"deleted: i64\"\n ", + "describe": { + "columns": [ + { + "name": "deleted: i64", + "ordinal": 0, + "type_info": "Null" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + null + ] + }, + "hash": "33f9a143409e6f436ed6b64abfbb9e06a5889460743202ca7956acabf006843e" +} diff --git a/.sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json b/.sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json new file mode 100644 index 0000000..2974cb0 --- /dev/null +++ b/.sqlx/query-45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f.json @@ -0,0 +1,26 @@ +{ + "db_name": "SQLite", + "query": "\n\t\t\t\tinsert into message\n\t\t\t\t\t(id, channel, sender, sent_at, sent_sequence, body)\n\t\t\t\tvalues ($1, $2, $3, $4, $5, $6)\n\t\t\t\treturning\n\t\t\t\t\tid as \"id: Id\",\n\t\t\t\t\tbody\n\t\t\t", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 1, + "type_info": "Text" + } + ], + "parameters": { + "Right": 6 + }, + "nullable": [ + false, + false + ] + }, + "hash": "45449846ea98e892c6e58f2ba8e082c21c9e3d124e5340dec036edd341d94e0f" +} diff --git a/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json b/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json deleted file mode 100644 index 494e1db..0000000 --- a/.sqlx/query-4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n\t\t\t\tinsert into message\n\t\t\t\t\t(id, channel, sender, sent_at, 
sent_sequence, body)\n\t\t\t\tvalues ($1, $2, $3, $4, $5, $6)\n\t\t\t\treturning\n\t\t\t\t\tid as \"id: message::Id\",\n\t\t\t\t\tsender as \"sender: login::Id\",\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\",\n\t\t\t\t\tbody\n\t\t\t", - "describe": { - "columns": [ - { - "name": "id: message::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "sender: login::Id", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "sent_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - }, - { - "name": "body", - "ordinal": 4, - "type_info": "Text" - } - ], - "parameters": { - "Right": 6 - }, - "nullable": [ - false, - false, - false, - false, - false - ] - }, - "hash": "4715007e2395ad30433b7405a144db4901c302bbcd3e76da6c61742ac44345c9" -} diff --git a/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json b/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json new file mode 100644 index 0000000..fb5f94b --- /dev/null +++ b/.sqlx/query-4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc.json @@ -0,0 +1,32 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n message.id as \"message: Id\"\n from message\n join channel on message.channel = channel.id\n where sent_at < $1\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "message: Id", + "ordinal": 2, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "4d4dce1b034f4a540f49490b1a8433a8ca334f1d666b104823e3fb0c08efb2cc" +} diff --git a/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json b/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json deleted file mode 100644 index 820b43f..0000000 --- a/.sqlx/query-5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n channel.created_sequence as \"channel_created_sequence: Sequence\",\n message.id as \"message: message::Id\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where sent_at < $1\n ", - "describe": { - "columns": [ - { - "name": "channel_id: channel::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "channel_created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "channel_created_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - }, - { - "name": "message: message::Id", - "ordinal": 4, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - false - ] - }, - "hash": "5244f04bc270fc8d3cd4116854398e2151ba2dba10c03a9d2d93184141f1425c" -} diff --git a/.sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json b/.sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json new file mode 100644 index 0000000..4ca6786 
--- /dev/null +++ b/.sqlx/query-5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3.json @@ -0,0 +1,62 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as \"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where coalesce(message.sent_sequence > $1, true)\n ", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sender_id: login::Id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "id: Id", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 7, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "5c53579fa431b6e184faf94eedb8229360ba78f2607d25e7e2ee5db5c759a5a3" +} diff --git a/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json b/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json deleted file mode 100644 index b34443f..0000000 --- a/.sqlx/query-74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n select\n channel.id as \"id: Id\",\n channel.name,\n channel.created_at as \"created_at: DateTime\",\n channel.created_sequence as \"created_sequence: Sequence\"\n from channel\n left join message\n where created_at < $1\n and message.id is null\n ", - "describe": { - "columns": [ - { - "name": "id: Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "name", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "created_at: DateTime", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "created_sequence: Sequence", - "ordinal": 3, - "type_info": "Integer" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "74f0bad30dcec743d77309b8df33083c2da765dfda3023c78c25c06735670457" -} diff --git a/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json b/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json deleted file mode 100644 index f546438..0000000 --- a/.sqlx/query-7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n\t\t\t\tselect\n\t\t\t\t\tmessage.id as \"id: message::Id\",\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n channel.created_at as \"channel_created_at: DateTime\",\n channel.created_sequence as \"channel_created_sequence: Sequence\",\n\t\t\t\t\tsender.id as \"sender_id: login::Id\",\n\t\t\t\t\tsender.name as sender_name,\n message.sent_at as \"sent_at: DateTime\",\n message.sent_sequence as \"sent_sequence: Sequence\",\n message.body\n\t\t\t\tfrom message\n join channel on message.channel = 
channel.id\n\t\t\t\t\tjoin login as sender on message.sender = sender.id\n\t\t\t\twhere coalesce(message.sent_sequence > $1, true)\n\t\t\t\torder by sent_sequence asc\n\t\t\t", - "describe": { - "columns": [ - { - "name": "id: message::Id", - "ordinal": 0, - "type_info": "Text" - }, - { - "name": "channel_id: channel::Id", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "channel_name", - "ordinal": 2, - "type_info": "Text" - }, - { - "name": "channel_created_at: DateTime", - "ordinal": 3, - "type_info": "Text" - }, - { - "name": "channel_created_sequence: Sequence", - "ordinal": 4, - "type_info": "Integer" - }, - { - "name": "sender_id: login::Id", - "ordinal": 5, - "type_info": "Text" - }, - { - "name": "sender_name", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "sent_at: DateTime", - "ordinal": 7, - "type_info": "Text" - }, - { - "name": "sent_sequence: Sequence", - "ordinal": 8, - "type_info": "Integer" - }, - { - "name": "body", - "ordinal": 9, - "type_info": "Text" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - false, - false, - false - ] - }, - "hash": "7e816ede017bc2635c11ab72b18b7af92ac1f1faed9df41df90f57cb596cfe7c" -} diff --git a/.sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json b/.sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json new file mode 100644 index 0000000..b82727f --- /dev/null +++ b/.sqlx/query-b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"id: Id\"\n from channel\n left join message\n where created_at < $1\n and message.id is null\n ", + "describe": { + "columns": [ + { + "name": "id: Id", + "ordinal": 0, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false + ] + }, + "hash": "b7e05e2b2eb5484c954bcedf5de8c7d9ad11bb09a0d0d243181d97c79d771071" +} diff --git a/.sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json b/.sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json deleted file mode 100644 index 1d448d4..0000000 --- a/.sqlx/query-d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n delete from channel\n where id = $1\n returning 1 as \"row: i64\"\n ", - "describe": { - "columns": [ - { - "name": "row: i64", - "ordinal": 0, - "type_info": "Null" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - null - ] - }, - "hash": "d382215ac9e9d8d2c9b5eb6f1c2744bdb7a93c39ebcf15087c89bba6be71f7cb" -} diff --git a/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json b/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json new file mode 100644 index 0000000..288a657 --- /dev/null +++ b/.sqlx/query-e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7.json @@ -0,0 +1,62 @@ +{ + "db_name": "SQLite", + "query": "\n select\n channel.id as \"channel_id: channel::Id\",\n channel.name as \"channel_name\",\n sender.id as \"sender_id: login::Id\",\n sender.name as \"sender_name\",\n message.id as \"id: Id\",\n message.body,\n sent_at as \"sent_at: DateTime\",\n sent_sequence as \"sent_sequence: Sequence\"\n from message\n join channel on message.channel = channel.id\n join login as sender on message.sender = sender.id\n where message.id = $1\n and message.channel = $2\n 
", + "describe": { + "columns": [ + { + "name": "channel_id: channel::Id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "channel_name", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "sender_id: login::Id", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "sender_name", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "id: Id", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "body", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "sent_at: DateTime", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "sent_sequence: Sequence", + "ordinal": 7, + "type_info": "Integer" + } + ], + "parameters": { + "Right": 2 + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "e93702ad922c7ce802499e99e837106c799e84015425286b79f42e4001d8a4c7" +} diff --git a/.sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json b/.sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json deleted file mode 100644 index 7b1d2d8..0000000 --- a/.sqlx/query-f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "SQLite", - "query": "\n delete from message\n where id = $1\n returning 1 as \"row: i64\"\n ", - "describe": { - "columns": [ - { - "name": "row: i64", - "ordinal": 0, - "type_info": "Null" - } - ], - "parameters": { - "Right": 1 - }, - "nullable": [ - null - ] - }, - "hash": "f5d5b3ec3554a80230e29676cdd9450fd1e8b4f2425cfda141d72fd94d3c39f9" -} diff --git a/Cargo.lock b/Cargo.lock index aa2120a..b1f0582 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -787,6 +787,7 @@ dependencies = [ "faker_rand", "futures", "headers", + "itertools", "password-hash", "rand", "rand_core", @@ -956,6 +957,15 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" diff --git a/Cargo.toml b/Cargo.toml index cb46b41..2b2e774 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ chrono = { version = "0.4.38", features = ["serde"] } clap = { version = "4.5.18", features = ["derive", "env"] } futures = "0.3.30" headers = "0.4.0" +itertools = "0.13.0" password-hash = { version = "0.5.0", features = ["std"] } rand = "0.8.5" rand_core = { version = "0.6.4", features = ["getrandom"] } diff --git a/src/app.rs b/src/app.rs index 5542e5f..186e5f8 100644 --- a/src/app.rs +++ b/src/app.rs @@ -4,6 +4,7 @@ use crate::{ channel::app::Channels, event::{app::Events, broadcaster::Broadcaster as EventBroadcaster}, login::app::Logins, + message::app::Messages, token::{app::Tokens, broadcaster::Broadcaster as TokenBroadcaster}, }; @@ -35,6 +36,10 @@ impl App { Logins::new(&self.db) } + pub const fn messages(&self) -> Messages { + Messages::new(&self.db, &self.events) + } + pub const fn tokens(&self) -> Tokens { Tokens::new(&self.db, &self.tokens) } diff --git a/src/broadcast.rs b/src/broadcast.rs index 083a301..bedc263 100644 --- a/src/broadcast.rs +++ b/src/broadcast.rs @@ -32,7 +32,7 @@ where { // panic: if ``message.channel.id`` has not been previously registered, // and was not part of the initial set of channels. 
- pub fn broadcast(&self, message: &M) { + pub fn broadcast(&self, message: impl Into<M>) { let tx = self.sender(); // Per the Tokio docs, the returned error is only used to indicate that @@ -42,7 +42,7 @@ where // // The successful return value, which includes the number of active // receivers, also isn't that interesting to us. - let _ = tx.send(message.clone()); + let _ = tx.send(message.into()); } // panic: if ``channel`` has not been previously registered, and was not diff --git a/src/channel/app.rs b/src/channel/app.rs index b7e3a10..6ce826b 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,10 +1,11 @@ use chrono::TimeDelta; +use itertools::Itertools; use sqlx::sqlite::SqlitePool; use crate::{ channel::{repo::Provider as _, Channel}, clock::DateTime, - event::{broadcaster::Broadcaster, repo::Provider as _, types::ChannelEvent, Sequence}, + event::{broadcaster::Broadcaster, repo::Provider as _, Sequence}, }; pub struct Channels<'a> { @@ -27,10 +28,11 @@ impl<'a> Channels<'a> { .map_err(|err| CreateError::from_duplicate_name(err, name))?; tx.commit().await?; - self.events - .broadcast(&ChannelEvent::created(channel.clone())); + for event in channel.events() { + self.events.broadcast(event); + } - Ok(channel) + Ok(channel.snapshot()) } pub async fn all(&self, resume_point: Option<Sequence>) -> Result<Vec<Channel>, InternalError> { @@ -38,6 +40,16 @@ impl<'a> Channels<'a> { let channels = tx.channels().all(resume_point).await?; tx.commit().await?; + let channels = channels + .into_iter() + .filter_map(|channel| { + channel + .events() + .filter(Sequence::up_to(resume_point)) + .collect() + }) + .collect(); + Ok(channels) } @@ -51,14 +63,21 @@ impl<'a> Channels<'a> { let mut events = Vec::with_capacity(expired.len()); for channel in expired { let deleted = tx.sequence().next(relative_to).await?; - let event = tx.channels().delete(&channel, &deleted).await?; - events.push(event); + let channel = tx.channels().delete(&channel, &deleted).await?; + events.push( + channel + .events() + .filter(Sequence::start_from(deleted.sequence)), + ); } tx.commit().await?; - for event in events { - self.events.broadcast(&event); + for event in events + .into_iter() + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + { + self.events.broadcast(event); } Ok(()) diff --git a/src/channel/event.rs b/src/channel/event.rs new file mode 100644 index 0000000..9c54174 --- /dev/null +++ b/src/channel/event.rs @@ -0,0 +1,48 @@ +use super::Channel; +use crate::{ + channel, + event::{Instant, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + Created(Created), + Deleted(Deleted), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Created { + pub channel: Channel, +} + +impl From<Created> for Kind { + fn from(event: Created) -> Self { + Self::Created(event) + } } + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Deleted { + pub channel: channel::Id, +} + +impl From<Deleted> for Kind { + fn from(event: Deleted) -> Self { + Self::Deleted(event) + } +} diff --git a/src/channel/history.rs b/src/channel/history.rs new file mode 100644 index 0000000..3cc7d9d --- /dev/null +++ b/src/channel/history.rs @@ -0,0 +1,42 @@ +use super::{ + event::{Created, Deleted,
Event}, + Channel, +}; +use crate::event::Instant; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct History { + pub channel: Channel, + pub created: Instant, + pub deleted: Option<Instant>, +} + +impl History { + fn created(&self) -> Event { + Event { + instant: self.created, + kind: Created { + channel: self.channel.clone(), + } + .into(), + } + } + + fn deleted(&self) -> Option<Event> { + self.deleted.map(|instant| Event { + instant, + kind: Deleted { + channel: self.channel.id.clone(), + } + .into(), + }) + } + + pub fn events(&self) -> impl Iterator<Item = Event> { + [self.created()].into_iter().chain(self.deleted()) + } + + pub fn snapshot(&self) -> Channel { + self.channel.clone() + } +} diff --git a/src/channel/mod.rs b/src/channel/mod.rs index 4baa7e3..eb8200b 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -1,16 +1,9 @@ -use crate::event::Instant; - pub mod app; +pub mod event; +mod history; mod id; pub mod repo; mod routes; +mod snapshot; -pub use self::{id::Id, routes::router}; - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Channel { - pub id: Id, - pub name: String, - #[serde(skip)] - pub created: Instant, -} +pub use self::{event::Event, history::History, id::Id, routes::router, snapshot::Channel}; diff --git a/src/channel/repo.rs b/src/channel/repo.rs index c000b56..8bb761b 100644 --- a/src/channel/repo.rs +++ b/src/channel/repo.rs @@ -1,9 +1,9 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ - channel::{Channel, Id}, + channel::{Channel, History, Id}, clock::DateTime, - event::{types, Instant, Sequence}, + event::{Instant, Sequence}, }; pub trait Provider { @@ -19,7 +19,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); impl<'c> Channels<'c> { - pub async fn create(&mut self, name: &str, created: &Instant) -> Result<Channel, sqlx::Error> { + pub async fn create(&mut self, name: &str, created: &Instant) -> Result<History, sqlx::Error> { let id = Id::generate(); let channel = sqlx::query!( r#" @@ -37,13 +37,16 @@ impl<'c> Channels<'c> { created.at, created.sequence, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_one(&mut *self.0) .await?; @@ -51,7 +54,7 @@ impl<'c> Channels<'c> { Ok(channel) } - pub async fn by_id(&mut self, channel: &Id) -> Result<Channel, sqlx::Error> { + pub async fn by_id(&mut self, channel: &Id) -> Result<History, sqlx::Error> { let channel = sqlx::query!( r#" select @@ -64,13 +67,16 @@ impl<'c> Channels<'c> { "#, channel, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_one(&mut *self.0) .await?; @@ -81,7 +87,7 @@ impl<'c> Channels<'c> { pub async fn all( &mut self, resume_point: Option<Sequence>, - ) -> Result<Vec<Channel>, sqlx::Error> { + ) -> Result<Vec<History>, sqlx::Error> { let channels = sqlx::query!( r#" select @@ -95,13 +101,16 @@ impl<'c> Channels<'c> { "#, resume_point, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_all(&mut *self.0) .await?; @@ -112,7 +121,7 @@ impl<'c> Channels<'c> { pub async fn replay( &mut self, resume_at: Option<Sequence>, - ) -> Result<Vec<Channel>, sqlx::Error> { + ) -> Result<Vec<History>, sqlx::Error> { let channels =
sqlx::query!( r#" select @@ -125,13 +134,16 @@ impl<'c> Channels<'c> { "#, resume_at, ) - .map(|row| Channel { - id: row.id, - name: row.name, + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, created: Instant { at: row.created_at, sequence: row.created_sequence, }, + deleted: None, }) .fetch_all(&mut *self.0) .await?; @@ -141,35 +153,43 @@ impl<'c> Channels<'c> { pub async fn delete( &mut self, - channel: &Channel, + channel: &Id, deleted: &Instant, - ) -> Result<types::ChannelEvent, sqlx::Error> { - let channel = channel.id.clone(); - sqlx::query_scalar!( + ) -> Result<History, sqlx::Error> { + let channel = sqlx::query!( r#" delete from channel where id = $1 - returning 1 as "row: i64" + returning + id as "id: Id", + name, + created_at as "created_at: DateTime", + created_sequence as "created_sequence: Sequence" "#, channel, ) + .map(|row| History { + channel: Channel { + id: row.id, + name: row.name, + }, + created: Instant { + at: row.created_at, + sequence: row.created_sequence, + }, + deleted: Some(*deleted), + }) .fetch_one(&mut *self.0) .await?; - Ok(types::ChannelEvent { - instant: *deleted, - data: types::DeletedEvent { channel }.into(), - }) + Ok(channel) } - pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Channel>, sqlx::Error> { - let channels = sqlx::query!( + pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Id>, sqlx::Error> { + let channels = sqlx::query_scalar!( r#" select - channel.id as "id: Id", - channel.name, - channel.created_at as "created_at: DateTime", - channel.created_sequence as "created_sequence: Sequence" + channel.id as "id: Id" from channel left join message where created_at < $1 @@ -177,14 +197,6 @@ impl<'c> Channels<'c> { "#, expired_at, ) - .map(|row| Channel { - id: row.id, - name: row.name, - created: Instant { - at: row.created_at, - sequence: row.created_sequence, - }, - }) .fetch_all(&mut *self.0) .await?; diff --git a/src/channel/routes.rs b/src/channel/routes.rs index 5d8b61e..5bb1ee9 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -13,8 +13,9 @@ use crate::{ channel::{self, Channel}, clock::RequestedAt, error::Internal, - event::{app::EventsError, Sequence}, + event::Sequence, login::Login, + message::app::Error as MessageError, }; #[cfg(test)] @@ -99,8 +100,8 @@ async fn on_send( login: Login, Json(request): Json, ) -> Result { - app.events() - .send(&login, &channel, &request.message, &sent_at) + app.messages() + .send(&channel, &login, &sent_at, &request.message) .await // Could impl `From` here, but it's more code and this is used once.
.map_err(ErrorResponse)?; @@ -109,13 +110,13 @@ } #[derive(Debug)] -struct ErrorResponse(EventsError); +struct ErrorResponse(MessageError); impl IntoResponse for ErrorResponse { fn into_response(self) -> Response { let Self(error) = self; match error { - not_found @ EventsError::ChannelNotFound(_) => { + not_found @ MessageError::ChannelNotFound(_) => { (StatusCode::NOT_FOUND, not_found.to_string()).into_response() } other => Internal::from(other).into_response(), diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 9988932..5733c9e 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -3,7 +3,7 @@ use futures::stream::StreamExt as _; use crate::{ channel::{app, routes}, - event::types, + event, test::fixtures::{self, future::Immediately as _}, }; @@ -50,8 +50,8 @@ async fn new_channel() { .expect("creation event published"); assert!(matches!( - event.data, - types::ChannelEventData::Created(event) + event.kind, + event::Kind::ChannelCreated(event) if event.channel == response_channel )); } diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 33ec3b7..1027b29 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -4,7 +4,8 @@ use futures::stream::StreamExt; use crate::{ channel, channel::routes, - event::{app, types}, + event, + message::app, test::fixtures::{self, future::Immediately as _}, }; @@ -54,10 +55,10 @@ async fn messages_in_order() { for ((sent_at, message), event) in requests.into_iter().zip(events) { assert_eq!(*sent_at, event.instant.at); assert!(matches!( - event.data, - types::ChannelEventData::Message(event_message) - if event_message.sender == sender - && event_message.message.body == message + event.kind, + event::Kind::MessageSent(event) + if event.message.sender == sender + && event.message.body == message )); } } @@ -90,6 +91,6 @@ async fn nonexistent_channel() { assert!(matches!( error, - app::EventsError::ChannelNotFound(error_channel) if channel == error_channel + app::Error::ChannelNotFound(error_channel) if channel == error_channel )); } diff --git a/src/channel/snapshot.rs b/src/channel/snapshot.rs new file mode 100644 index 0000000..6462f25 --- /dev/null +++ b/src/channel/snapshot.rs @@ -0,0 +1,38 @@ +use super::{ + event::{Created, Event, Kind}, + Id, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Channel { + pub id: Id, + pub name: String, +} + +impl Channel { + fn apply(state: Option<Self>, event: Event) -> Option<Self> { + match (state, event.kind) { + (None, Kind::Created(event)) => Some(event.into()), + (Some(channel), Kind::Deleted(event)) if channel.id == event.channel => None, + (state, event) => panic!("invalid channel event {event:#?} for state {state:#?}"), + } + } +} + +impl FromIterator<Event> for Option<Channel> { + fn from_iter<I: IntoIterator<Item = Event>>(events: I) -> Self { + events.into_iter().fold(None, Channel::apply) + } +} + +impl From<&Created> for Channel { + fn from(event: &Created) -> Self { + event.channel.clone() + } +} + +impl From<Created> for Channel { + fn from(event: Created) -> Self { + event.channel + } +} diff --git a/src/event/app.rs b/src/event/app.rs index 5e9e79a..e58bea9 100644 --- a/src/event/app.rs +++ b/src/event/app.rs @@ -1,22 +1,15 @@ -use chrono::TimeDelta; use futures::{ future, stream::{self, StreamExt as _}, Stream, }; +use itertools::Itertools as _; use sqlx::sqlite::SqlitePool; -use super::{ - broadcaster::Broadcaster, - repo::message::Provider as _, - types::{self,
ChannelEvent}, -}; +use super::{broadcaster::Broadcaster, Event, Sequence, Sequenced}; use crate::{ channel::{self, repo::Provider as _}, - clock::DateTime, - db::NotFound as _, - event::{repo::Provider as _, Sequence}, - login::Login, + message::{self, repo::Provider as _}, }; pub struct Events<'a> { @@ -29,111 +22,52 @@ impl<'a> Events<'a> { Self { db, events } } - pub async fn send( - &self, - login: &Login, - channel: &channel::Id, - body: &str, - sent_at: &DateTime, - ) -> Result<ChannelEvent, EventsError> { - let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let sent = tx.sequence().next(sent_at).await?; - let event = tx - .message_events() - .create(login, &channel, &sent, body) - .await?; - tx.commit().await?; - - self.events.broadcast(&event); - Ok(event) - } - - pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { - // Somewhat arbitrarily, expire after 90 days. - let expire_at = relative_to.to_owned() - TimeDelta::days(90); - - let mut tx = self.db.begin().await?; - let expired = tx.message_events().expired(&expire_at).await?; - - let mut events = Vec::with_capacity(expired.len()); - for (channel, message) in expired { - let deleted = tx.sequence().next(relative_to).await?; - let event = tx - .message_events() - .delete(&channel, &message, &deleted) - .await?; - events.push(event); - } - - tx.commit().await?; - - for event in events { - self.events.broadcast(&event); - } - - Ok(()) - } - pub async fn subscribe( &self, resume_at: Option<Sequence>, - ) -> Result<impl Stream<Item = ChannelEvent> + std::fmt::Debug, sqlx::Error> { + ) -> Result<impl Stream<Item = Event> + std::fmt::Debug, sqlx::Error> { // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. let live_messages = self.events.subscribe(); let mut tx = self.db.begin().await?; - let channels = tx.channels().replay(resume_at).await?; + let channels = tx.channels().replay(resume_at).await?; let channel_events = channels - .into_iter() - .map(ChannelEvent::created) - .filter(move |event| { - resume_at.map_or(true, |resume_at| Sequence::from(event) > resume_at) - }); - - let message_events = tx.message_events().replay(resume_at).await?; - - let mut replay_events = channel_events - .into_iter() - .chain(message_events.into_iter()) + .iter() + .map(channel::History::events) + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .filter(Sequence::after(resume_at)) + .map(Event::from); + + let messages = tx.messages().replay(resume_at).await?; + let message_events = messages + .iter() + .map(message::History::events) + .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence) + .filter(Sequence::after(resume_at)) + .map(Event::from); + + let replay_events = channel_events + .merge_by(message_events, |a, b| { + a.instant.sequence < b.instant.sequence + }) .collect::<Vec<_>>(); - replay_events.sort_by_key(|event| Sequence::from(event)); - let resume_live_at = replay_events.last().map(Sequence::from); + let resume_live_at = replay_events.last().map(Sequenced::sequence); let replay = stream::iter(replay_events); - // no skip_expired or resume transforms for stored_messages, as it's - // constructed not to contain messages meeting either criterion. - // - // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; - // * resume is redundant with the resume_at argument to - // `tx.broadcasts().replay(…)`.
let live_messages = live_messages // Filtering on the broadcast resume point filters out messages // before resume_at, and filters out messages duplicated from - // stored_messages. + // `replay_events`. .filter(Self::resume(resume_live_at)); Ok(replay.chain(live_messages)) } - fn resume( - resume_at: Option<Sequence>, - ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { - move |event| future::ready(resume_at < Some(Sequence::from(event))) + fn resume(resume_at: Option<Sequence>) -> impl for<'m> FnMut(&'m Event) -> future::Ready<bool> { + let filter = Sequence::after(resume_at); + move |event| future::ready(filter(event)) } } - -#[derive(Debug, thiserror::Error)] -pub enum EventsError { - #[error("channel {0} not found")] - ChannelNotFound(channel::Id), - #[error(transparent)] - DatabaseError(#[from] sqlx::Error), -} diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs index 92f631f..de2513a 100644 --- a/src/event/broadcaster.rs +++ b/src/event/broadcaster.rs @@ -1,3 +1,3 @@ -use crate::{broadcast, event::types}; +use crate::broadcast; -pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; +pub type Broadcaster = broadcast::Broadcaster<super::Event>; diff --git a/src/event/mod.rs b/src/event/mod.rs index c982d3a..1503b77 100644 --- a/src/event/mod.rs +++ b/src/event/mod.rs @@ -1,18 +1,75 @@ +use crate::{channel, message}; + pub mod app; pub mod broadcaster; mod extract; pub mod repo; mod routes; mod sequence; -pub mod types; -use crate::clock::DateTime; +pub use self::{ + routes::router, + sequence::{Instant, Sequence, Sequenced}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Event { + #[serde(flatten)] + pub instant: Instant, + #[serde(flatten)] + pub kind: Kind, +} + +impl Sequenced for Event { + fn instant(&self) -> Instant { + self.instant + } +} + +impl From<channel::Event> for Event { + fn from(event: channel::Event) -> Self { + Self { + instant: event.instant, + kind: event.kind.into(), + } + } +} + +impl From<message::Event> for Event { + fn from(event: message::Event) -> Self { + Self { + instant: event.instant, + kind: event.kind.into(), + } + } } -pub use self::{routes::router, sequence::Sequence}; +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Kind { + #[serde(rename = "created")] + ChannelCreated(channel::event::Created), + #[serde(rename = "message")] + MessageSent(message::event::Sent), + MessageDeleted(message::event::Deleted), + #[serde(rename = "deleted")] + ChannelDeleted(channel::event::Deleted), +} + +impl From<channel::event::Kind> for Kind { + fn from(kind: channel::event::Kind) -> Self { + match kind { + channel::event::Kind::Created(created) => Self::ChannelCreated(created), + channel::event::Kind::Deleted(deleted) => Self::ChannelDeleted(deleted), + } + } } -#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Instant { - pub at: DateTime, - #[serde(skip)] - pub sequence: Sequence, +impl From<message::event::Kind> for Kind { + fn from(kind: message::event::Kind) -> Self { + match kind { + message::event::Kind::Sent(created) => Self::MessageSent(created), + message::event::Kind::Deleted(deleted) => Self::MessageDeleted(deleted), + } + } } diff --git a/src/event/repo.rs b/src/event/repo.rs new file mode 100644 index 0000000..40d6a53 --- /dev/null +++ b/src/event/repo.rs @@ -0,0 +1,50 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + clock::DateTime, + event::{Instant, Sequence}, +}; + +pub trait Provider { + fn sequence(&mut self) -> Sequences; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn
sequence(&mut self) -> Sequences { + Sequences(self) + } +} + +pub struct Sequences<'t>(&'t mut SqliteConnection); + +impl<'c> Sequences<'c> { + pub async fn next(&mut self, at: &DateTime) -> Result<Instant, sqlx::Error> { + let next = sqlx::query_scalar!( + r#" + update event_sequence + set last_value = last_value + 1 + returning last_value as "next_value: Sequence" + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(Instant { + at: *at, + sequence: next, + }) + } + + pub async fn current(&mut self) -> Result<Sequence, sqlx::Error> { + let next = sqlx::query_scalar!( + r#" + select last_value as "last_value: Sequence" + from event_sequence + "#, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } +} diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs deleted file mode 100644 index f29c8a4..0000000 --- a/src/event/repo/message.rs +++ /dev/null @@ -1,196 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - channel::{self, Channel}, - clock::DateTime, - event::{types, Instant, Sequence}, - login::{self, Login}, - message::{self, Message}, -}; - -pub trait Provider { - fn message_events(&mut self) -> Events; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn message_events(&mut self) -> Events { - Events(self) - } -} - -pub struct Events<'t>(&'t mut SqliteConnection); - -impl<'c> Events<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - sent: &Instant, - body: &str, - ) -> Result<types::ChannelEvent, sqlx::Error> { - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, channel, sender, sent_at, sent_sequence, body) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: message::Id", - sender as "sender: login::Id", - sent_at as "sent_at: DateTime", - sent_sequence as "sent_sequence: Sequence", - body - "#, - id, - channel.id, - sender.id, - sent.at, - sent.sequence, - body, - ) - .map(|row| types::ChannelEvent { - instant: Instant { - at: row.sent_at, - sequence: row.sent_sequence, - }, - data: types::MessageEvent { - channel: channel.clone(), - sender: sender.clone(), - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - pub async fn delete( - &mut self, - channel: &Channel, - message: &message::Id, - deleted: &Instant, - ) -> Result<types::ChannelEvent, sqlx::Error> { - sqlx::query_scalar!( - r#" - delete from message - where id = $1 - returning 1 as "row: i64" - "#, - message, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(types::ChannelEvent { - instant: Instant { - at: deleted.at, - sequence: deleted.sequence, - }, - data: types::MessageDeletedEvent { - channel: channel.clone(), - message: message.clone(), - } - .into(), - }) - } - - pub async fn expired( - &mut self, - expire_at: &DateTime, - ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - message.id as "message: message::Id" - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where sent_at < $1 - "#, - expire_at, - ) - .map(|row| { - ( - Channel { - id: row.channel_id, - name: row.channel_name, - created: Instant { - at: row.channel_created_at, - sequence: row.channel_created_sequence, - }, - }, - row.message, - ) - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - } - - pub async fn replay( - &mut self,
resume_at: Option<Sequence>, - ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { - let events = sqlx::query!( - r#" - select - message.id as "id: message::Id", - channel.id as "channel_id: channel::Id", - channel.name as "channel_name", - channel.created_at as "channel_created_at: DateTime", - channel.created_sequence as "channel_created_sequence: Sequence", - sender.id as "sender_id: login::Id", - sender.name as sender_name, - message.sent_at as "sent_at: DateTime", - message.sent_sequence as "sent_sequence: Sequence", - message.body - from message - join channel on message.channel = channel.id - join login as sender on message.sender = sender.id - where coalesce(message.sent_sequence > $1, true) - order by sent_sequence asc - "#, - resume_at, - ) - .map(|row| types::ChannelEvent { - instant: Instant { - at: row.sent_at, - sequence: row.sent_sequence, - }, - data: types::MessageEvent { - channel: Channel { - id: row.channel_id, - name: row.channel_name, - created: Instant { - at: row.channel_created_at, - sequence: row.channel_created_sequence, - }, - }, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - message: Message { - id: row.id, - body: row.body, - }, - } - .into(), - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(events) - } -} diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs deleted file mode 100644 index cee840c..0000000 --- a/src/event/repo/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod message; -mod sequence; - -pub use self::sequence::Provider; diff --git a/src/event/repo/sequence.rs b/src/event/repo/sequence.rs deleted file mode 100644 index 40d6a53..0000000 --- a/src/event/repo/sequence.rs +++ /dev/null @@ -1,50 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - clock::DateTime, - event::{Instant, Sequence}, -}; - -pub trait Provider { - fn sequence(&mut self) -> Sequences; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn sequence(&mut self) -> Sequences { - Sequences(self) - } -} - -pub struct Sequences<'t>(&'t mut SqliteConnection); - -impl<'c> Sequences<'c> { - pub async fn next(&mut self, at: &DateTime) -> Result<Instant, sqlx::Error> { - let next = sqlx::query_scalar!( - r#" - update event_sequence - set last_value = last_value + 1 - returning last_value as "next_value: Sequence" - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(Instant { - at: *at, - sequence: next, - }) - } - - pub async fn current(&mut self) -> Result<Sequence, sqlx::Error> { - let next = sqlx::query_scalar!( - r#" - select last_value as "last_value: Sequence" - from event_sequence - "#, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(next) - } -} diff --git a/src/event/routes.rs b/src/event/routes.rs index c87bfb2..5b9c7e3 100644 --- a/src/event/routes.rs +++ b/src/event/routes.rs @@ -10,11 +10,11 @@ use axum::{ }; use axum_extra::extract::Query; use futures::stream::{Stream, StreamExt as _}; -use super::{extract::LastEventId, types}; +use super::{extract::LastEventId, Event}; use crate::{ app::App, error::{Internal, Unauthorized}, - event::Sequence, + event::{Sequence, Sequenced as _}, token::{app::ValidateError, extract::Identity}, }; @@ -35,7 +35,7 @@ async fn events( identity: Identity, last_event_id: Option<LastEventId<Sequence>>, Query(query): Query, -) -> Result<Events<impl Stream<Item = types::ChannelEvent> + std::fmt::Debug>, EventsError> { +) -> Result<Events<impl Stream<Item = Event> + std::fmt::Debug>, EventsError> { let resume_at = last_event_id .map(LastEventId::into_inner) .or(query.resume_point); @@ -51,7 +51,7 @@ struct Events<S>(S); impl<S> IntoResponse for Events<S> where - S: Stream<Item = types::ChannelEvent> + Send + 'static, + S: Stream<Item = Event> + Send + 'static, { fn into_response(self) -> Response { let Self(stream)
        let Self(stream) = self;
@@ -62,11 +62,11 @@ where
     }
 }
 
-impl TryFrom<types::ChannelEvent> for sse::Event {
+impl TryFrom<Event> for sse::Event {
     type Error = serde_json::Error;
 
-    fn try_from(event: types::ChannelEvent) -> Result<Self, Self::Error> {
-        let id = serde_json::to_string(&Sequence::from(&event))?;
+    fn try_from(event: Event) -> Result<Self, Self::Error> {
+        let id = serde_json::to_string(&event.sequence())?;
         let data = serde_json::to_string_pretty(&event)?;
 
         let event = Self::default().id(id).data(data);
diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs
index 68b55cc..ba9953e 100644
--- a/src/event/routes/test.rs
+++ b/src/event/routes/test.rs
@@ -6,7 +6,7 @@ use futures::{
 };
 
 use crate::{
-    event::{routes, Sequence},
+    event::{routes, Sequenced as _},
     test::fixtures::{self, future::Immediately as _},
 };
 
@@ -17,7 +17,7 @@ async fn includes_historical_message() {
     let app = fixtures::scratch_app().await;
     let sender = fixtures::login::create(&app).await;
     let channel = fixtures::channel::create(&app, &fixtures::now()).await;
-    let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+    let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await;
 
     // Call the endpoint
 
@@ -36,7 +36,7 @@
         .await
         .expect("delivered stored message");
 
-    assert_eq!(message, event);
+    assert!(fixtures::event::message_sent(&event, &message));
 }
 
 #[tokio::test]
@@ -58,7 +58,7 @@
     // Verify the semantics
 
     let sender = fixtures::login::create(&app).await;
-    let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+    let message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await;
 
     let event = events
         .filter(fixtures::filter::messages())
@@ -67,7 +67,7 @@
         .await
         .expect("delivered live message");
 
-    assert_eq!(message, event);
+    assert!(fixtures::event::message_sent(&event, &message));
 }
 
 #[tokio::test]
@@ -87,7 +87,7 @@
             let app = app.clone();
             let sender = sender.clone();
             let channel = channel.clone();
-            async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await }
+            async move { fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await }
         })
         .collect::<Vec<_>>()
         .await;
@@ -110,7 +110,9 @@
         .await;
 
     for message in &messages {
-        assert!(events.iter().any(|event| { event == message }));
+        assert!(events
+            .iter()
+            .any(|event| fixtures::event::message_sent(event, message)));
     }
 }
 
@@ -123,9 +125,9 @@ async fn sequential_messages() {
     let sender = fixtures::login::create(&app).await;
 
     let messages = vec![
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
     ];
 
     // Call the endpoint
@@ -138,7 +140,13 @@
 
     // Verify the structure of the response.
 
-    let mut events = events.filter(|event| future::ready(messages.contains(event)));
+    let mut events = events.filter(|event| {
+        future::ready(
+            messages
+                .iter()
+                .any(|message| fixtures::event::message_sent(event, message)),
+        )
+    });
 
     // Verify delivery in order
     for message in &messages {
@@ -148,7 +156,7 @@
             .await
             .expect("undelivered messages remaining");
 
-        assert_eq!(message, &event);
+        assert!(fixtures::event::message_sent(&event, message));
     }
 }
 
@@ -160,11 +168,11 @@ async fn resumes_from() {
     let channel = fixtures::channel::create(&app, &fixtures::now()).await;
     let sender = fixtures::login::create(&app).await;
 
-    let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await;
+    let initial_message = fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await;
 
     let later_messages = vec![
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
     ];
 
     // Call the endpoint
@@ -190,9 +198,9 @@
            .await
            .expect("delivered events");
 
-        assert_eq!(initial_message, event);
+        assert!(fixtures::event::message_sent(&event, &initial_message));
 
-        Sequence::from(&event)
+        event.sequence()
     };
 
     // Resume after disconnect
@@ -214,7 +222,9 @@
        .await;
 
    for message in &later_messages {
-        assert!(events.iter().any(|event| event == message));
+        assert!(events
+            .iter()
+            .any(|event| fixtures::event::message_sent(event, message)));
    }
 }
 
@@ -249,8 +259,8 @@ async fn serial_resume() {
 
     let resume_at = {
         let initial_messages = [
-            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-            fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+            fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+            fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await,
         ];
 
         // First subscription
@@ -271,12 +281,14 @@
             .await;
 
         for message in &initial_messages {
-            assert!(events.iter().any(|event| event == message));
+            assert!(events
+                .iter()
+                .any(|event| fixtures::event::message_sent(event, message)));
         }
 
         let event = events.last().expect("this vec is non-empty");
 
-        Sequence::from(event)
+        event.sequence()
     };
 
     // Resume after disconnect
@@ -285,8 +297,8 @@
        // Note that channel_b does not appear here. The buggy behaviour
        // would be masked if channel_b happened to send a new message
        // into the resumed event stream.
-        fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
     ];
 
     // Second subscription
@@ -307,12 +319,14 @@
             .await;
 
         for message in &resume_messages {
-            assert!(events.iter().any(|event| event == message));
+            assert!(events
+                .iter()
+                .any(|event| fixtures::event::message_sent(event, message)));
         }
 
         let event = events.last().expect("this vec is non-empty");
 
-        Sequence::from(event)
+        event.sequence()
     };
 
     // Resume after disconnect a second time
@@ -321,8 +335,8 @@
        // problem. The resume point should be before both of these messages, but
        // after _all_ prior messages.
        let final_messages = [
-            fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await,
-            fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await,
+            fixtures::message::send(&app, &channel_a, &sender, &fixtures::now()).await,
+            fixtures::message::send(&app, &channel_b, &sender, &fixtures::now()).await,
        ];
 
        // Third subscription
@@ -345,7 +359,9 @@
        // This set of messages, in particular, _should not_ include any prior
        // messages from `initial_messages` or `resume_messages`.
        for message in &final_messages {
-            assert!(events.iter().any(|event| event == message));
+            assert!(events
+                .iter()
+                .any(|event| fixtures::event::message_sent(event, message)));
        }
    };
 }
@@ -378,13 +394,17 @@
 
     // These should not be delivered.
     let messages = [
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
     ];
 
     assert!(events
-        .filter(|event| future::ready(messages.contains(event)))
+        .filter(|event| future::ready(
+            messages
+                .iter()
+                .any(|message| fixtures::event::message_sent(event, message))
+        ))
         .next()
         .immediately()
         .await
@@ -425,13 +445,17 @@
 
     // These should not be delivered.
     let messages = [
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
-        fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
+        fixtures::message::send(&app, &channel, &sender, &fixtures::now()).await,
     ];
 
     assert!(events
-        .filter(|event| future::ready(messages.contains(event)))
+        .filter(|event| future::ready(
+            messages
+                .iter()
+                .any(|message| fixtures::event::message_sent(event, message))
+        ))
         .next()
         .immediately()
         .await
diff --git a/src/event/sequence.rs b/src/event/sequence.rs
index 9ebddd7..c566156 100644
--- a/src/event/sequence.rs
+++ b/src/event/sequence.rs
@@ -1,5 +1,20 @@
 use std::fmt;
 
+use crate::clock::DateTime;
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Instant {
+    pub at: DateTime,
+    #[serde(skip)]
+    pub sequence: Sequence,
+}
+
+impl From<Instant> for Sequence {
+    fn from(instant: Instant) -> Self {
+        instant.sequence
+    }
+}
+
 #[derive(
     Clone,
     Copy,
@@ -22,3 +37,47 @@ impl fmt::Display for Sequence {
         value.fmt(f)
     }
 }
+
+impl Sequence {
+    pub fn up_to<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool
+    where
+        E: Sequenced,
+    {
+        move |event| resume_point.map_or(true, |resume_point| event.sequence() <= resume_point)
+    }
+
+    pub fn after<E>(resume_point: Option<Self>) -> impl for<'e> Fn(&'e E) -> bool
+    where
+        E: Sequenced,
+    {
+        move |event| resume_point < Some(event.sequence())
+    }
+
+    pub fn start_from<E>(resume_point: Self) -> impl for<'e> Fn(&'e E) -> bool
+    where
+        E: Sequenced,
+    {
+        move |event| resume_point <= event.sequence()
+    }
+}
+
+pub trait Sequenced {
+    fn instant(&self) -> Instant;
+
+    fn sequence(&self) -> Sequence {
+        self.instant().into()
+    }
+}
+
+impl<E> Sequenced for &E
+where
+    E: Sequenced,
+{
+    fn instant(&self) -> Instant {
+        (*self).instant()
+    }
+
+    fn sequence(&self) -> Sequence {
+        (*self).sequence()
+    }
+}
diff --git a/src/event/types.rs b/src/event/types.rs
deleted file mode 100644
index 2324dc1..0000000
--- a/src/event/types.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-use crate::{
-    channel::{self, Channel},
-    event::{Instant, Sequence},
-    login::Login,
-    message::{self, Message},
-};
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct ChannelEvent {
-    #[serde(flatten)]
-    pub instant: Instant,
-    #[serde(flatten)]
-    pub data: ChannelEventData,
-}
-
-impl ChannelEvent {
-    pub fn created(channel: Channel) -> Self {
-        Self {
-            instant: channel.created,
-            data: CreatedEvent { channel }.into(),
-        }
-    }
-}
-
-impl<'c> From<&'c ChannelEvent> for Sequence {
-    fn from(event: &'c ChannelEvent) -> Self {
-        event.instant.sequence
-    }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ChannelEventData {
-    Created(CreatedEvent),
-    Message(MessageEvent),
-    MessageDeleted(MessageDeletedEvent),
-    Deleted(DeletedEvent),
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct CreatedEvent {
-    pub channel: Channel,
-}
-
-impl From<CreatedEvent> for ChannelEventData {
-    fn from(event: CreatedEvent) -> Self {
-        Self::Created(event)
-    }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct MessageEvent {
-    pub channel: Channel,
-    pub sender: Login,
-    pub message: Message,
-}
-
-impl From<MessageEvent> for ChannelEventData {
-    fn from(event: MessageEvent) -> Self {
-        Self::Message(event)
-    }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct MessageDeletedEvent {
-    pub channel: Channel,
-    pub message: message::Id,
-}
-
-impl From<MessageDeletedEvent> for ChannelEventData {
-    fn from(event: MessageDeletedEvent) -> Self {
-        Self::MessageDeleted(event)
-    }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct DeletedEvent {
-    pub channel: channel::Id,
-}
-
-impl From<DeletedEvent> for ChannelEventData {
-    fn from(event: DeletedEvent) -> Self {
-        Self::Deleted(event)
-    }
-}
diff --git a/src/expire.rs b/src/expire.rs
index a8eb8ad..e50bcb4 100644
--- a/src/expire.rs
+++ b/src/expire.rs
@@ -14,7 +14,7 @@ pub async fn middleware(
     next: Next,
 ) -> Result {
     app.tokens().expire(&expired_at).await?;
-    app.events().expire(&expired_at).await?;
+    app.messages().expire(&expired_at).await?;
     app.channels().expire(&expired_at).await?;
     Ok(next.run(req).await)
 }
diff --git a/src/message/app.rs b/src/message/app.rs
new file mode 100644
index 0000000..51f772e
--- /dev/null
+++ b/src/message/app.rs
@@ -0,0 +1,88 @@
+use chrono::TimeDelta;
+use itertools::Itertools;
+use sqlx::sqlite::SqlitePool;
+
+use super::{repo::Provider as _, Message};
+use crate::{
+    channel::{self, repo::Provider as _},
+    clock::DateTime,
+    db::NotFound as _,
+    event::{broadcaster::Broadcaster, repo::Provider as _, Sequence},
+    login::Login,
+};
+
+pub struct Messages<'a> {
+    db: &'a SqlitePool,
+    events: &'a Broadcaster,
+}
+
+impl<'a> Messages<'a> {
+    pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self {
+        Self { db, events }
+    }
+
+    pub async fn send(
+        &self,
+        channel: &channel::Id,
+        sender: &Login,
+        sent_at: &DateTime,
+        body: &str,
+    ) -> Result<Message, Error> {
+        let mut tx = self.db.begin().await?;
+        let channel = tx
+            .channels()
+            .by_id(channel)
+            .await
+            .not_found(|| Error::ChannelNotFound(channel.clone()))?;
+        let sent = tx.sequence().next(sent_at).await?;
+        let message = tx
+            .messages()
+            .create(&channel.snapshot(), sender, &sent, body)
+            .await?;
+        tx.commit().await?;
+
+        for event in message.events() {
+            self.events.broadcast(event);
+        }
+
+        Ok(message.snapshot())
+    }
+
+    pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> {
+        // Somewhat arbitrarily, expire after 90 days.
+        let expire_at = relative_to.to_owned() - TimeDelta::days(90);
+
+        let mut tx = self.db.begin().await?;
+        let expired = tx.messages().expired(&expire_at).await?;
+
+        let mut events = Vec::with_capacity(expired.len());
+        for (channel, message) in expired {
+            let deleted = tx.sequence().next(relative_to).await?;
+            let message = tx.messages().delete(&channel, &message, &deleted).await?;
+            events.push(
+                message
+                    .events()
+                    .filter(Sequence::start_from(deleted.sequence)),
+            );
+        }
+
+        tx.commit().await?;
+
+        for event in events
+            .into_iter()
+            .kmerge_by(|a, b| a.instant.sequence < b.instant.sequence)
+        {
+            self.events.broadcast(event);
+        }
+
+        Ok(())
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("channel {0} not found")]
+    ChannelNotFound(channel::Id),
+    #[error(transparent)]
+    DatabaseError(#[from] sqlx::Error),
+}
diff --git a/src/message/event.rs b/src/message/event.rs
new file mode 100644
index 0000000..bcc2238
--- /dev/null
+++ b/src/message/event.rs
@@ -0,0 +1,50 @@
+use super::{snapshot::Message, Id};
+use crate::{
+    channel::Channel,
+    event::{Instant, Sequenced},
+};
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Event {
+    #[serde(flatten)]
+    pub instant: Instant,
+    #[serde(flatten)]
+    pub kind: Kind,
+}
+
+impl Sequenced for Event {
+    fn instant(&self) -> Instant {
+        self.instant
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Kind {
+    Sent(Sent),
+    Deleted(Deleted),
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Sent {
+    #[serde(flatten)]
+    pub message: Message,
+}
+
+impl From<Sent> for Kind {
+    fn from(event: Sent) -> Self {
+        Self::Sent(event)
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
+pub struct Deleted {
+    pub channel: Channel,
+    pub message: Id,
+}
+
+impl From<Deleted> for Kind {
+    fn from(event: Deleted) -> Self {
+        Self::Deleted(event)
+    }
+}
diff --git a/src/message/history.rs b/src/message/history.rs
new file mode 100644
index 0000000..5aca47e
--- /dev/null
+++ b/src/message/history.rs
@@ -0,0 +1,43 @@
+use super::{
+    event::{Deleted, Event, Sent},
+    Message,
+};
+use crate::event::Instant;
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct History {
+    pub message: Message,
+    pub sent: Instant,
+    pub deleted: Option<Instant>,
+}
+
+impl History {
+    fn sent(&self) -> Event {
+        Event {
+            instant: self.sent,
+            kind: Sent {
+                message: self.message.clone(),
+            }
+            .into(),
+        }
+    }
+
+    fn deleted(&self) -> Option<Event> {
+        self.deleted.map(|instant| Event {
+            instant,
+            kind: Deleted {
+                channel: self.message.channel.clone(),
+                message: self.message.id.clone(),
+            }
+            .into(),
+        })
+    }
+
+    pub fn events(&self) -> impl Iterator<Item = Event> {
+        [self.sent()].into_iter().chain(self.deleted())
+    }
+
+    pub fn snapshot(&self) -> Message {
+        self.message.clone()
+    }
+}
diff --git a/src/message/mod.rs b/src/message/mod.rs
index 9a9bf14..52d56c1 100644
--- a/src/message/mod.rs
+++ b/src/message/mod.rs
@@ -1,9 +1,8 @@
+pub mod app;
+pub mod event;
+mod history;
 mod id;
+pub mod repo;
+mod snapshot;
 
-pub use self::id::Id;
-
-#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)]
-pub struct Message {
-    pub id: Id,
-    pub body: String,
-}
+pub use self::{event::Event, history::History, id::Id, snapshot::Message};
diff --git a/src/message/repo.rs b/src/message/repo.rs
new file mode 100644
index 0000000..3b2b8f7
--- /dev/null
+++ b/src/message/repo.rs
@@ -0,0 +1,214 @@
+use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction};
+
+use super::{snapshot::Message, History, Id};
+use crate::{
+    channel::{self, Channel},
+    clock::DateTime,
+    event::{Instant, Sequence},
+    login::{self, Login},
+};
+
+pub trait Provider {
+    fn messages(&mut self) -> Messages;
+}
+
+impl<'c> Provider for Transaction<'c, Sqlite> {
+    fn messages(&mut self) -> Messages {
+        Messages(self)
+    }
+}
+
+pub struct Messages<'t>(&'t mut SqliteConnection);
+
+impl<'c> Messages<'c> {
+    pub async fn create(
+        &mut self,
+        channel: &Channel,
+        sender: &Login,
+        sent: &Instant,
+        body: &str,
+    ) -> Result<History, sqlx::Error> {
+        let id = Id::generate();
+
+        let message = sqlx::query!(
+            r#"
+            insert into message
+                (id, channel, sender, sent_at, sent_sequence, body)
+            values ($1, $2, $3, $4, $5, $6)
+            returning
+                id as "id: Id",
+                body
+            "#,
+            id,
+            channel.id,
+            sender.id,
+            sent.at,
+            sent.sequence,
+            body,
+        )
+        .map(|row| History {
+            message: Message {
+                channel: channel.clone(),
+                sender: sender.clone(),
+                id: row.id,
+                body: row.body,
+            },
+            sent: *sent,
+            deleted: None,
+        })
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(message)
+    }
+
+    async fn by_id(&mut self, channel: &Channel, message: &Id) -> Result<History, sqlx::Error> {
+        let message = sqlx::query!(
+            r#"
+            select
+                channel.id as "channel_id: channel::Id",
+                channel.name as "channel_name",
+                sender.id as "sender_id: login::Id",
+                sender.name as "sender_name",
+                message.id as "id: Id",
+                message.body,
+                sent_at as "sent_at: DateTime",
+                sent_sequence as "sent_sequence: Sequence"
+            from message
+            join channel on message.channel = channel.id
+            join login as sender on message.sender = sender.id
+            where message.id = $1
+            and message.channel = $2
+            "#,
+            message,
+            channel.id,
+        )
+        .map(|row| History {
+            message: Message {
+                channel: Channel {
+                    id: row.channel_id,
+                    name: row.channel_name,
+                },
+                sender: Login {
+                    id: row.sender_id,
+                    name: row.sender_name,
+                },
+                id: row.id,
+                body: row.body,
+            },
+            sent: Instant {
+                at: row.sent_at,
+                sequence: row.sent_sequence,
+            },
+            deleted: None,
+        })
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(message)
+    }
+
+    pub async fn delete(
+        &mut self,
+        channel: &Channel,
+        message: &Id,
+        deleted: &Instant,
+    ) -> Result<History, sqlx::Error> {
+        let history = self.by_id(channel, message).await?;
+
+        sqlx::query_scalar!(
+            r#"
+            delete from message
+            where
+                id = $1
+            returning 1 as "deleted: i64"
+            "#,
+            history.message.id,
+        )
+        .fetch_one(&mut *self.0)
+        .await?;
+
+        Ok(History {
+            deleted: Some(*deleted),
+            ..history
+        })
+    }
+
+    pub async fn expired(
+        &mut self,
+        expire_at: &DateTime,
+    ) -> Result<Vec<(Channel, Id)>, sqlx::Error> {
+        let messages = sqlx::query!(
+            r#"
+            select
+                channel.id as "channel_id: channel::Id",
+                channel.name as "channel_name",
+                message.id as "message: Id"
+            from message
+            join channel on message.channel = channel.id
+            where sent_at < $1
+            "#,
+            expire_at,
+        )
+        .map(|row| {
+            (
+                Channel {
+                    id: row.channel_id,
+                    name: row.channel_name,
+                },
+                row.message,
+            )
+        })
+        .fetch_all(&mut *self.0)
+        .await?;
+
+        Ok(messages)
+    }
+
+    pub async fn replay(
+        &mut self,
+        resume_at: Option<Sequence>,
+    ) -> Result<Vec<History>, sqlx::Error> {
+        let messages = sqlx::query!(
+            r#"
+            select
+                channel.id as "channel_id: channel::Id",
"channel_name", + sender.id as "sender_id: login::Id", + sender.name as "sender_name", + message.id as "id: Id", + message.body, + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + "#, + resume_at, + ) + .map(|row| History { + message: Message { + channel: Channel { + id: row.channel_id, + name: row.channel_name, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + id: row.id, + body: row.body, + }, + sent: Instant { + at: row.sent_at, + sequence: row.sent_sequence, + }, + deleted: None, + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } +} diff --git a/src/message/snapshot.rs b/src/message/snapshot.rs new file mode 100644 index 0000000..3adccbe --- /dev/null +++ b/src/message/snapshot.rs @@ -0,0 +1,74 @@ +use super::{ + event::{Event, Kind, Sent}, + Id, +}; +use crate::{channel::Channel, login::Login}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(into = "self::serialize::Message")] +pub struct Message { + pub channel: Channel, + pub sender: Login, + pub id: Id, + pub body: String, +} + +mod serialize { + use crate::{channel::Channel, login::Login, message::Id}; + + #[derive(serde::Serialize)] + pub struct Message { + channel: Channel, + sender: Login, + #[allow(clippy::struct_field_names)] + // Deliberately redundant with the module path; this produces a specific serialization. + message: MessageData, + } + + #[derive(serde::Serialize)] + pub struct MessageData { + id: Id, + body: String, + } + + impl From for Message { + fn from(message: super::Message) -> Self { + Self { + channel: message.channel, + sender: message.sender, + message: MessageData { + id: message.id, + body: message.body, + }, + } + } + } +} + +impl Message { + fn apply(state: Option, event: Event) -> Option { + match (state, event.kind) { + (None, Kind::Sent(event)) => Some(event.into()), + (Some(message), Kind::Deleted(event)) if message.id == event.message => None, + (state, event) => panic!("invalid message event {event:#?} for state {state:#?}"), + } + } +} + +impl FromIterator for Option { + fn from_iter>(events: I) -> Self { + events.into_iter().fold(None, Message::apply) + } +} + +impl From<&Sent> for Message { + fn from(event: &Sent) -> Self { + event.message.clone() + } +} + +impl From for Message { + fn from(event: Sent) -> Self { + event.message + } +} diff --git a/src/test/fixtures/event.rs b/src/test/fixtures/event.rs new file mode 100644 index 0000000..09f0490 --- /dev/null +++ b/src/test/fixtures/event.rs @@ -0,0 +1,11 @@ +use crate::{ + event::{Event, Kind}, + message::Message, +}; + +pub fn message_sent(event: &Event, message: &Message) -> bool { + matches!( + &event.kind, + Kind::MessageSent(event) if message == &event.into() + ) +} diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs index d1939a5..6e62aea 100644 --- a/src/test/fixtures/filter.rs +++ b/src/test/fixtures/filter.rs @@ -1,11 +1,11 @@ use futures::future; -use crate::event::types; +use crate::event::{Event, Kind}; -pub fn messages() -> impl FnMut(&types::ChannelEvent) -> future::Ready { - |event| future::ready(matches!(event.data, types::ChannelEventData::Message(_))) +pub fn messages() -> impl FnMut(&Event) -> future::Ready { + |event| future::ready(matches!(event.kind, Kind::MessageSent(_))) } -pub fn created() -> impl FnMut(&types::ChannelEvent) -> future::Ready { - 
-    |event| future::ready(matches!(event.data, types::ChannelEventData::Created(_)))
+pub fn created() -> impl FnMut(&Event) -> future::Ready<bool> {
+    |event| future::ready(matches!(event.kind, Kind::ChannelCreated(_)))
 }
diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs
index fd50887..381b10b 100644
--- a/src/test/fixtures/message.rs
+++ b/src/test/fixtures/message.rs
@@ -1,17 +1,12 @@
 use faker_rand::lorem::Paragraphs;
 
-use crate::{app::App, channel::Channel, clock::RequestedAt, event::types, login::Login};
+use crate::{app::App, channel::Channel, clock::RequestedAt, login::Login, message::Message};
 
-pub async fn send(
-    app: &App,
-    login: &Login,
-    channel: &Channel,
-    sent_at: &RequestedAt,
-) -> types::ChannelEvent {
+pub async fn send(app: &App, channel: &Channel, login: &Login, sent_at: &RequestedAt) -> Message {
     let body = propose();
 
-    app.events()
-        .send(login, &channel.id, &body, sent_at)
+    app.messages()
+        .send(&channel.id, login, sent_at, &body)
         .await
         .expect("should succeed if the channel exists")
 }
diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs
index 76467ab..c5efa9b 100644
--- a/src/test/fixtures/mod.rs
+++ b/src/test/fixtures/mod.rs
@@ -3,6 +3,7 @@ use chrono::{TimeDelta, Utc};
 use crate::{app::App, clock::RequestedAt, db};
 
 pub mod channel;
+pub mod event;
 pub mod filter;
 pub mod future;
 pub mod identity;
diff --git a/src/token/app.rs b/src/token/app.rs
index 030ec69..5c4fcd5 100644
--- a/src/token/app.rs
+++ b/src/token/app.rs
@@ -127,7 +127,7 @@ impl<'a> Tokens<'a> {
         tx.commit().await?;
 
         for event in tokens.into_iter().map(event::TokenRevoked::from) {
-            self.tokens.broadcast(&event);
+            self.tokens.broadcast(event);
         }
 
         Ok(())
@@ -139,7 +139,7 @@
         tx.commit().await?;
 
         self.tokens
-            .broadcast(&event::TokenRevoked::from(token.clone()));
+            .broadcast(event::TokenRevoked::from(token.clone()));
 
         Ok(())
     }
-- 
cgit v1.2.3