diff options
Diffstat (limited to 'src')
34 files changed, 931 insertions, 809 deletions
@@ -13,9 +13,9 @@ pub struct App { } impl App { - pub async fn from(db: SqlitePool) -> Result<Self, sqlx::Error> { - let broadcaster = Broadcaster::from_database(&db).await?; - Ok(Self { db, broadcaster }) + pub fn from(db: SqlitePool) -> Self { + let broadcaster = Broadcaster::default(); + Self { db, broadcaster } } } diff --git a/src/channel/app.rs b/src/channel/app.rs index 793fa35..d7312e4 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,7 +1,9 @@ +use chrono::TimeDelta; use sqlx::sqlite::SqlitePool; use crate::{ - events::broadcaster::Broadcaster, + clock::DateTime, + events::{broadcaster::Broadcaster, repo::message::Provider as _, types::ChannelEvent}, repo::channel::{Channel, Provider as _}, }; @@ -15,16 +17,18 @@ impl<'a> Channels<'a> { Self { db, broadcaster } } - pub async fn create(&self, name: &str) -> Result<Channel, CreateError> { + pub async fn create(&self, name: &str, created_at: &DateTime) -> Result<Channel, CreateError> { let mut tx = self.db.begin().await?; let channel = tx .channels() - .create(name) + .create(name, created_at) .await .map_err(|err| CreateError::from_duplicate_name(err, name))?; - self.broadcaster.register_channel(&channel.id); tx.commit().await?; + self.broadcaster + .broadcast(&ChannelEvent::created(channel.clone())); + Ok(channel) } @@ -35,6 +39,32 @@ impl<'a> Channels<'a> { Ok(channels) } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 90 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(90); + + let mut tx = self.db.begin().await?; + let expired = tx.channels().expired(&expire_at).await?; + + let mut events = Vec::with_capacity(expired.len()); + for channel in expired { + let sequence = tx.message_events().assign_sequence(&channel).await?; + let event = tx + .channels() + .delete_expired(&channel, sequence, relative_to) + .await?; + events.push(event); + } + + tx.commit().await?; + + for event in events { + self.broadcaster.broadcast(&event); + } + + Ok(()) + } } #[derive(Debug, thiserror::Error)] diff --git a/src/channel/routes.rs b/src/channel/routes.rs index f524e62..1f8db5a 100644 --- a/src/channel/routes.rs +++ b/src/channel/routes.rs @@ -52,11 +52,12 @@ struct CreateRequest { async fn on_create( State(app): State<App>, _: Login, // requires auth, but doesn't actually care who you are + RequestedAt(created_at): RequestedAt, Json(form): Json<CreateRequest>, ) -> Result<Json<Channel>, CreateError> { let channel = app .channels() - .create(&form.name) + .create(&form.name, &created_at) .await .map_err(CreateError)?; diff --git a/src/channel/routes/test/list.rs b/src/channel/routes/test/list.rs index f7f7b44..bc94024 100644 --- a/src/channel/routes/test/list.rs +++ b/src/channel/routes/test/list.rs @@ -26,7 +26,7 @@ async fn one_channel() { let app = fixtures::scratch_app().await; let viewer = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; // Call the endpoint @@ -46,8 +46,8 @@ async fn multiple_channels() { let app = fixtures::scratch_app().await; let viewer = fixtures::login::create(&app).await; let channels = vec![ - fixtures::channel::create(&app).await, - fixtures::channel::create(&app).await, + fixtures::channel::create(&app, &fixtures::now()).await, + fixtures::channel::create(&app, &fixtures::now()).await, ]; // Call the endpoint diff --git a/src/channel/routes/test/on_create.rs b/src/channel/routes/test/on_create.rs index 23885c0..e2610a5 100644 --- a/src/channel/routes/test/on_create.rs +++ b/src/channel/routes/test/on_create.rs @@ -1,8 +1,10 @@ use axum::extract::{Json, State}; +use futures::stream::StreamExt as _; use crate::{ channel::{app, routes}, - test::fixtures, + events::types, + test::fixtures::{self, future::Immediately as _}, }; #[tokio::test] @@ -16,10 +18,14 @@ async fn new_channel() { let name = fixtures::channel::propose(); let request = routes::CreateRequest { name }; - let Json(response_channel) = - routes::on_create(State(app.clone()), creator, Json(request.clone())) - .await - .expect("new channel in an empty app"); + let Json(response_channel) = routes::on_create( + State(app.clone()), + creator, + fixtures::now(), + Json(request.clone()), + ) + .await + .expect("new channel in an empty app"); // Verify the structure of the response @@ -28,8 +34,27 @@ async fn new_channel() { // Verify the semantics let channels = app.channels().all().await.expect("always succeeds"); - assert!(channels.contains(&response_channel)); + + let mut events = app + .events() + .subscribe(types::ResumePoint::default()) + .await + .expect("subscribing never fails") + .filter(fixtures::filter::created()); + + let types::ResumableEvent(_, event) = events + .next() + .immediately() + .await + .expect("creation event published"); + + assert_eq!(types::Sequence::default(), event.sequence); + assert!(matches!( + event.data, + types::ChannelEventData::Created(event) + if event.channel == response_channel + )); } #[tokio::test] @@ -38,15 +63,19 @@ async fn duplicate_name() { let app = fixtures::scratch_app().await; let creator = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; // Call the endpoint let request = routes::CreateRequest { name: channel.name }; - let routes::CreateError(error) = - routes::on_create(State(app.clone()), creator, Json(request.clone())) - .await - .expect_err("duplicate channel name"); + let routes::CreateError(error) = routes::on_create( + State(app.clone()), + creator, + fixtures::now(), + Json(request.clone()), + ) + .await + .expect_err("duplicate channel name"); // Verify the structure of the response diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 93a5480..233518b 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -1,90 +1,33 @@ -use axum::{ - extract::{Json, Path, State}, - http::StatusCode, -}; +use axum::extract::{Json, Path, State}; use futures::stream::StreamExt; use crate::{ channel::routes, - events::app, + events::{app, types}, repo::channel, test::fixtures::{self, future::Immediately as _}, }; #[tokio::test] -async fn channel_exists() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; - - // Call the endpoint - - let sent_at = fixtures::now(); - let request = routes::SendRequest { - message: fixtures::message::propose(), - }; - let status = routes::on_send( - State(app.clone()), - Path(channel.id.clone()), - sent_at.clone(), - sender.clone(), - Json(request.clone()), - ) - .await - .expect("sending to a valid channel"); - - // Verify the structure of the response - - assert_eq!(StatusCode::ACCEPTED, status); - - // Verify the semantics - - let subscribed_at = fixtures::now(); - let mut events = app - .events() - .subscribe(&channel.id, &subscribed_at, None) - .await - .expect("subscribing to a valid channel"); - - let event = events - .next() - .immediately() - .await - .expect("event received by subscribers"); - - assert_eq!(request.message, event.body); - assert_eq!(sender, event.sender); - assert_eq!(*sent_at, event.sent_at); -} - -#[tokio::test] async fn messages_in_order() { // Set up the environment let app = fixtures::scratch_app().await; let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; // Call the endpoint (twice) let requests = vec![ - ( - fixtures::now(), - routes::SendRequest { - message: fixtures::message::propose(), - }, - ), - ( - fixtures::now(), - routes::SendRequest { - message: fixtures::message::propose(), - }, - ), + (fixtures::now(), fixtures::message::propose()), + (fixtures::now(), fixtures::message::propose()), ]; - for (sent_at, request) in &requests { + for (sent_at, message) in &requests { + let request = routes::SendRequest { + message: message.clone(), + }; + routes::on_send( State(app.clone()), Path(channel.id.clone()), @@ -98,20 +41,24 @@ async fn messages_in_order() { // Verify the semantics - let subscribed_at = fixtures::now(); let events = app .events() - .subscribe(&channel.id, &subscribed_at, None) + .subscribe(types::ResumePoint::default()) .await .expect("subscribing to a valid channel") + .filter(fixtures::filter::messages()) .take(requests.len()); let events = events.collect::<Vec<_>>().immediately().await; - for ((sent_at, request), event) in requests.into_iter().zip(events) { - assert_eq!(request.message, event.body); - assert_eq!(sender, event.sender); - assert_eq!(*sent_at, event.sent_at); + for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) { + assert_eq!(*sent_at, event.at); + assert!(matches!( + event.data, + types::ChannelEventData::Message(event_message) + if event_message.sender == sender + && event_message.message.body == message + )); } } @@ -10,7 +10,7 @@ use clap::Parser; use sqlx::sqlite::SqlitePool; use tokio::net; -use crate::{app::App, channel, clock, events, login, repo::pool}; +use crate::{app::App, channel, clock, events, expire, login, repo::pool}; /// Command-line entry point for running the `hi` server. /// @@ -70,8 +70,12 @@ impl Args { pub async fn run(self) -> Result<(), Error> { let pool = self.pool().await?; - let app = App::from(pool).await?; + let app = App::from(pool); let app = routers() + .route_layer(middleware::from_fn_with_state( + app.clone(), + expire::middleware, + )) .route_layer(middleware::from_fn(clock::middleware)) .with_state(app); diff --git a/src/events/app.rs b/src/events/app.rs index 7229551..0cdc641 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use chrono::TimeDelta; use futures::{ future, @@ -8,7 +10,8 @@ use sqlx::sqlite::SqlitePool; use super::{ broadcaster::Broadcaster, - repo::broadcast::{self, Provider as _}, + repo::message::Provider as _, + types::{self, ChannelEvent, ResumePoint}, }; use crate::{ clock::DateTime, @@ -35,64 +38,87 @@ impl<'a> Events<'a> { channel: &channel::Id, body: &str, sent_at: &DateTime, - ) -> Result<broadcast::Message, EventsError> { + ) -> Result<types::ChannelEvent, EventsError> { let mut tx = self.db.begin().await?; let channel = tx .channels() .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let message = tx - .broadcast() + let event = tx + .message_events() .create(login, &channel, body, sent_at) .await?; tx.commit().await?; - self.broadcaster.broadcast(&channel.id, &message); - Ok(message) + self.broadcaster.broadcast(&event); + Ok(event) } - pub async fn subscribe( - &self, - channel: &channel::Id, - subscribed_at: &DateTime, - resume_at: Option<broadcast::Sequence>, - ) -> Result<impl Stream<Item = broadcast::Message> + std::fmt::Debug, EventsError> { + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { // Somewhat arbitrarily, expire after 90 days. - let expire_at = subscribed_at.to_owned() - TimeDelta::days(90); + let expire_at = relative_to.to_owned() - TimeDelta::days(90); let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let expired = tx.message_events().expired(&expire_at).await?; + + let mut events = Vec::with_capacity(expired.len()); + for (channel, message) in expired { + let sequence = tx.message_events().assign_sequence(&channel).await?; + let event = tx + .message_events() + .delete_expired(&channel, &message, sequence, relative_to) + .await?; + events.push(event); + } + + tx.commit().await?; + + for event in events { + self.broadcaster.broadcast(&event); + } + + Ok(()) + } + + pub async fn subscribe( + &self, + resume_at: ResumePoint, + ) -> Result<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug, sqlx::Error> { + let mut tx = self.db.begin().await?; + let channels = tx.channels().all().await?; + + let created_events = { + let resume_at = resume_at.clone(); + let channels = channels.clone(); + stream::iter( + channels + .into_iter() + .map(ChannelEvent::created) + .filter(move |event| resume_at.not_after(event)), + ) + }; // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. - let live_messages = self.broadcaster.subscribe(&channel.id); + let live_messages = self.broadcaster.subscribe(); - tx.broadcast().expire(&expire_at).await?; - let stored_messages = tx.broadcast().replay(&channel, resume_at).await?; - tx.commit().await?; + let mut replays = BTreeMap::new(); + let mut resume_live_at = resume_at.clone(); + for channel in channels { + let replay = tx + .message_events() + .replay(&channel, resume_at.get(&channel.id)) + .await?; - let resume_broadcast_at = stored_messages - .last() - .map(|message| message.sequence) - .or(resume_at); + if let Some(last) = replay.last() { + resume_live_at.advance(last); + } - // This should always be the case, up to integer rollover, primarily - // because every message in stored_messages has a sequence not less - // than `resume_at`, or `resume_at` is None. We use the last message - // (if any) to decide when to resume the `live_messages` stream. - // - // It probably simplifies to assert!(resume_at <= resume_broadcast_at), but - // this form captures more of the reasoning. - assert!( - (resume_at.is_none() && resume_broadcast_at.is_none()) - || (stored_messages.is_empty() && resume_at == resume_broadcast_at) - || resume_at < resume_broadcast_at - ); + replays.insert(channel.id.clone(), replay); + } + + let replay = stream::select_all(replays.into_values().map(stream::iter)); // no skip_expired or resume transforms for stored_messages, as it's // constructed not to contain messages meeting either criterion. @@ -100,34 +126,31 @@ impl<'a> Events<'a> { // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; // * resume is redundant with the resume_at argument to // `tx.broadcasts().replay(…)`. - let stored_messages = stream::iter(stored_messages); let live_messages = live_messages - // Sure, it's temporally improbable that we'll ever skip a message - // that's 90 days old, but there's no reason not to be thorough. - .filter(Self::skip_expired(&expire_at)) // Filtering on the broadcast resume point filters out messages // before resume_at, and filters out messages duplicated from // stored_messages. - .filter(Self::resume(resume_broadcast_at)); + .filter(Self::resume(resume_live_at)); + + Ok(created_events.chain(replay).chain(live_messages).scan( + resume_at, + |resume_point, event| { + match event.data { + types::ChannelEventData::Deleted(_) => resume_point.forget(&event), + _ => resume_point.advance(&event), + } - Ok(stored_messages.chain(live_messages)) + let event = types::ResumableEvent(resume_point.clone(), event); + + future::ready(Some(event)) + }, + )) } fn resume( - resume_at: Option<broadcast::Sequence>, - ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> { - move |msg| { - future::ready(match resume_at { - None => true, - Some(resume_at) => msg.sequence > resume_at, - }) - } - } - fn skip_expired( - expire_at: &DateTime, - ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> { - let expire_at = expire_at.to_owned(); - move |msg| future::ready(msg.sent_at > expire_at) + resume_at: ResumePoint, + ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { + move |event| future::ready(resume_at.not_after(event)) } } diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs index dcaba91..9697c0a 100644 --- a/src/events/broadcaster.rs +++ b/src/events/broadcaster.rs @@ -1,63 +1,35 @@ -use std::collections::{hash_map::Entry, HashMap}; -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::{Arc, Mutex}; use futures::{future, stream::StreamExt as _, Stream}; -use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast::{channel, Sender}; use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; -use crate::{ - events::repo::broadcast, - repo::channel::{self, Provider as _}, -}; +use crate::events::types; -// Clones will share the same senders collection. +// Clones will share the same sender. #[derive(Clone)] pub struct Broadcaster { // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that // lock it must be sync. - senders: Arc<Mutex<HashMap<channel::Id, Sender<broadcast::Message>>>>, + senders: Arc<Mutex<Sender<types::ChannelEvent>>>, } -impl Broadcaster { - pub async fn from_database(db: &SqlitePool) -> Result<Self, sqlx::Error> { - let mut tx = db.begin().await?; - let channels = tx.channels().all().await?; - tx.commit().await?; - - let channels = channels.iter().map(|c| &c.id); - let broadcaster = Self::new(channels); - Ok(broadcaster) - } - - fn new<'i>(channels: impl IntoIterator<Item = &'i channel::Id>) -> Self { - let senders: HashMap<_, _> = channels - .into_iter() - .cloned() - .map(|id| (id, Self::make_sender())) - .collect(); +impl Default for Broadcaster { + fn default() -> Self { + let sender = Self::make_sender(); Self { - senders: Arc::new(Mutex::new(senders)), - } - } - - // panic: if ``channel`` is already registered. - pub fn register_channel(&self, channel: &channel::Id) { - match self.senders().entry(channel.clone()) { - // This ever happening indicates a serious logic error. - Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"), - Entry::Vacant(entry) => { - entry.insert(Self::make_sender()); - } + senders: Arc::new(Mutex::new(sender)), } } +} - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - pub fn broadcast(&self, channel: &channel::Id, message: &broadcast::Message) { - let tx = self.sender(channel); +impl Broadcaster { + // panic: if ``message.channel.id`` has not been previously registered, + // and was not part of the initial set of channels. + pub fn broadcast(&self, message: &types::ChannelEvent) { + let tx = self.sender(); // Per the Tokio docs, the returned error is only used to indicate that // there are no receivers. In this use case, that's fine; a lack of @@ -71,15 +43,12 @@ impl Broadcaster { // panic: if ``channel`` has not been previously registered, and was not // part of the initial set of channels. - pub fn subscribe( - &self, - channel: &channel::Id, - ) -> impl Stream<Item = broadcast::Message> + std::fmt::Debug { - let rx = self.sender(channel).subscribe(); + pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug { + let rx = self.sender().subscribe(); BroadcastStream::from(rx).scan((), |(), r| { future::ready(match r { - Ok(message) => Some(message), + Ok(event) => Some(event), // Stop the stream here. This will disconnect SSE clients // (see `routes.rs`), who will then resume from // `Last-Event-ID`, allowing them to catch up by reading @@ -92,17 +61,11 @@ impl Broadcaster { }) } - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - fn sender(&self, channel: &channel::Id) -> Sender<broadcast::Message> { - self.senders()[channel].clone() - } - - fn senders(&self) -> MutexGuard<HashMap<channel::Id, Sender<broadcast::Message>>> { - self.senders.lock().unwrap() // propagate panics when mutex is poisoned + fn sender(&self) -> Sender<types::ChannelEvent> { + self.senders.lock().unwrap().clone() } - fn make_sender() -> Sender<broadcast::Message> { + fn make_sender() -> Sender<types::ChannelEvent> { // Queue depth of 16 chosen entirely arbitrarily. Don't read too much // into it. let (tx, _) = channel(16); diff --git a/src/events/mod.rs b/src/events/mod.rs index b9f3f5b..711ae64 100644 --- a/src/events/mod.rs +++ b/src/events/mod.rs @@ -3,5 +3,6 @@ pub mod broadcaster; mod extract; pub mod repo; mod routes; +pub mod types; pub use self::routes::router; diff --git a/src/events/repo/broadcast.rs b/src/events/repo/broadcast.rs deleted file mode 100644 index 6914573..0000000 --- a/src/events/repo/broadcast.rs +++ /dev/null @@ -1,162 +0,0 @@ -use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; - -use crate::{ - clock::DateTime, - repo::{ - channel::Channel, - login::{self, Login}, - message, - }, -}; - -pub trait Provider { - fn broadcast(&mut self) -> Broadcast; -} - -impl<'c> Provider for Transaction<'c, Sqlite> { - fn broadcast(&mut self) -> Broadcast { - Broadcast(self) - } -} - -pub struct Broadcast<'t>(&'t mut SqliteConnection); - -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Message { - pub id: message::Id, - #[serde(skip)] - pub sequence: Sequence, - pub sender: Login, - pub body: String, - pub sent_at: DateTime, -} - -impl<'c> Broadcast<'c> { - pub async fn create( - &mut self, - sender: &Login, - channel: &Channel, - body: &str, - sent_at: &DateTime, - ) -> Result<Message, sqlx::Error> { - let sequence = self.next_sequence_for(channel).await?; - - let id = message::Id::generate(); - - let message = sqlx::query!( - r#" - insert into message - (id, channel, sequence, sender, body, sent_at) - values ($1, $2, $3, $4, $5, $6) - returning - id as "id: message::Id", - sequence as "sequence: Sequence", - sender as "sender: login::Id", - body, - sent_at as "sent_at: DateTime" - "#, - id, - channel.id, - sequence, - sender.id, - body, - sent_at, - ) - .map(|row| Message { - id: row.id, - sequence: row.sequence, - sender: sender.clone(), - body: row.body, - sent_at: row.sent_at, - }) - .fetch_one(&mut *self.0) - .await?; - - Ok(message) - } - - async fn next_sequence_for(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> { - let Sequence(current) = sqlx::query_scalar!( - r#" - -- `max` never returns null, but sqlx can't detect that - select max(sequence) as "sequence!: Sequence" - from message - where channel = $1 - "#, - channel.id, - ) - .fetch_one(&mut *self.0) - .await?; - - Ok(Sequence(current + 1)) - } - - pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { - sqlx::query!( - r#" - delete from message - where sent_at < $1 - "#, - expire_at, - ) - .execute(&mut *self.0) - .await?; - - Ok(()) - } - - pub async fn replay( - &mut self, - channel: &Channel, - resume_at: Option<Sequence>, - ) -> Result<Vec<Message>, sqlx::Error> { - let messages = sqlx::query!( - r#" - select - message.id as "id: message::Id", - sequence as "sequence: Sequence", - login.id as "sender_id: login::Id", - login.name as sender_name, - message.body, - message.sent_at as "sent_at: DateTime" - from message - join login on message.sender = login.id - where channel = $1 - and coalesce(sequence > $2, true) - order by sequence asc - "#, - channel.id, - resume_at, - ) - .map(|row| Message { - id: row.id, - sequence: row.sequence, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - body: row.body, - sent_at: row.sent_at, - }) - .fetch_all(&mut *self.0) - .await?; - - Ok(messages) - } -} - -#[derive( - Debug, - Eq, - Ord, - PartialEq, - PartialOrd, - Clone, - Copy, - serde::Serialize, - serde::Deserialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); diff --git a/src/events/repo/message.rs b/src/events/repo/message.rs new file mode 100644 index 0000000..f8bae2b --- /dev/null +++ b/src/events/repo/message.rs @@ -0,0 +1,198 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + clock::DateTime, + events::types::{self, Sequence}, + repo::{ + channel::{self, Channel}, + login::{self, Login}, + message::{self, Message}, + }, +}; + +pub trait Provider { + fn message_events(&mut self) -> Events; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn message_events(&mut self) -> Events { + Events(self) + } +} + +pub struct Events<'t>(&'t mut SqliteConnection); + +impl<'c> Events<'c> { + pub async fn create( + &mut self, + sender: &Login, + channel: &Channel, + body: &str, + sent_at: &DateTime, + ) -> Result<types::ChannelEvent, sqlx::Error> { + let sequence = self.assign_sequence(channel).await?; + + let id = message::Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, channel, sequence, sender, body, sent_at) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: message::Id", + sequence as "sequence: Sequence", + sender as "sender: login::Id", + body, + sent_at as "sent_at: DateTime" + "#, + id, + channel.id, + sequence, + sender.id, + body, + sent_at, + ) + .map(|row| types::ChannelEvent { + sequence: row.sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: channel.clone(), + sender: sender.clone(), + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn assign_sequence(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> { + let next = sqlx::query_scalar!( + r#" + update channel + set last_sequence = last_sequence + 1 + where id = $1 + returning last_sequence as "next_sequence: Sequence" + "#, + channel.id, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(next) + } + + pub async fn delete_expired( + &mut self, + channel: &Channel, + message: &message::Id, + sequence: Sequence, + deleted_at: &DateTime, + ) -> Result<types::ChannelEvent, sqlx::Error> { + sqlx::query_scalar!( + r#" + delete from message + where id = $1 + returning 1 as "row: i64" + "#, + message, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence, + at: *deleted_at, + data: types::MessageDeletedEvent { + channel: channel.clone(), + message: message.clone(), + } + .into(), + }) + } + + pub async fn expired( + &mut self, + expire_at: &DateTime, + ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + message.id as "message: message::Id" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where sent_at < $1 + "#, + expire_at, + ) + .map(|row| { + ( + Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + }, + row.message, + ) + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + pub async fn replay( + &mut self, + channel: &Channel, + resume_at: Option<Sequence>, + ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { + let events = sqlx::query!( + r#" + select + message.id as "id: message::Id", + sequence as "sequence: Sequence", + login.id as "sender_id: login::Id", + login.name as sender_name, + message.body, + message.sent_at as "sent_at: DateTime" + from message + join login on message.sender = login.id + where channel = $1 + and coalesce(sequence > $2, true) + order by sequence asc + "#, + channel.id, + resume_at, + ) + .map(|row| types::ChannelEvent { + sequence: row.sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: channel.clone(), + sender: login::Login { + id: row.sender_id, + name: row.sender_name, + }, + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(events) + } +} diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs index 2ed3062..e216a50 100644 --- a/src/events/repo/mod.rs +++ b/src/events/repo/mod.rs @@ -1 +1 @@ -pub mod broadcast; +pub mod message; diff --git a/src/events/routes.rs b/src/events/routes.rs index d901f9b..89c942c 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -1,8 +1,5 @@ -use std::collections::{BTreeMap, HashSet}; - use axum::{ extract::State, - http::StatusCode, response::{ sse::{self, Sse}, IntoResponse, Response, @@ -10,87 +7,31 @@ use axum::{ routing::get, Router, }; -use axum_extra::extract::Query; -use futures::{ - future, - stream::{self, Stream, StreamExt as _, TryStreamExt as _}, -}; +use futures::stream::{Stream, StreamExt as _}; -use super::{extract::LastEventId, repo::broadcast}; -use crate::{ - app::App, - clock::RequestedAt, - error::Internal, - events::app::EventsError, - repo::{channel, login::Login}, +use super::{ + extract::LastEventId, + types::{self, ResumePoint}, }; +use crate::{app::App, error::Internal, repo::login::Login}; #[cfg(test)] mod test; -// For the purposes of event replay, an "event ID" is a vector of per-channel -// sequence numbers. Replay will start with messages whose sequence number in -// its channel is higher than the sequence in the event ID, or if the channel -// is not listed in the event ID, then at the beginning. -// -// Using a sorted map ensures that there is a canonical representation for -// each event ID. -type EventId = BTreeMap<channel::Id, broadcast::Sequence>; - pub fn router() -> Router<App> { Router::new().route("/api/events", get(events)) } -#[derive(Clone, serde::Deserialize)] -struct EventsQuery { - #[serde(default, rename = "channel")] - channels: HashSet<channel::Id>, -} - async fn events( State(app): State<App>, - RequestedAt(now): RequestedAt, _: Login, // requires auth, but doesn't actually care who you are - last_event_id: Option<LastEventId<EventId>>, - Query(query): Query<EventsQuery>, -) -> Result<Events<impl Stream<Item = ReplayableEvent> + std::fmt::Debug>, ErrorResponse> { + last_event_id: Option<LastEventId<ResumePoint>>, +) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> { let resume_at = last_event_id .map(LastEventId::into_inner) .unwrap_or_default(); - let streams = stream::iter(query.channels) - .then(|channel| { - let app = app.clone(); - let resume_at = resume_at.clone(); - async move { - let resume_at = resume_at.get(&channel).copied(); - - let events = app - .events() - .subscribe(&channel, &now, resume_at) - .await? - .map(ChannelEvent::wrap(channel)); - - Ok::<_, EventsError>(events) - } - }) - .try_collect::<Vec<_>>() - .await - // impl From would take more code; this is used once. - .map_err(ErrorResponse)?; - - // We resume counting from the provided last-event-id mapping, rather than - // starting from scratch, so that the events in a resumed stream contain - // the full vector of channel IDs for their event IDs right off the bat, - // even before any events are actually delivered. - let stream = stream::select_all(streams).scan(resume_at, |sequences, event| { - let (channel, sequence) = event.event_id(); - sequences.insert(channel, sequence); - - let event = ReplayableEvent(sequences.clone(), event); - - future::ready(Some(event)) - }); + let stream = app.events().subscribe(resume_at).await?; Ok(Events(stream)) } @@ -100,7 +41,7 @@ struct Events<S>(S); impl<S> IntoResponse for Events<S> where - S: Stream<Item = ReplayableEvent> + Send + 'static, + S: Stream<Item = types::ResumableEvent> + Send + 'static, { fn into_response(self) -> Response { let Self(stream) = self; @@ -111,51 +52,13 @@ where } } -#[derive(Debug)] -struct ErrorResponse(EventsError); - -impl IntoResponse for ErrorResponse { - fn into_response(self) -> Response { - let Self(error) = self; - match error { - not_found @ EventsError::ChannelNotFound(_) => { - (StatusCode::NOT_FOUND, not_found.to_string()).into_response() - } - other => Internal::from(other).into_response(), - } - } -} - -#[derive(Debug)] -struct ReplayableEvent(EventId, ChannelEvent); - -#[derive(Debug, serde::Serialize)] -struct ChannelEvent { - channel: channel::Id, - #[serde(flatten)] - message: broadcast::Message, -} - -impl ChannelEvent { - fn wrap(channel: channel::Id) -> impl Fn(broadcast::Message) -> Self { - move |message| Self { - channel: channel.clone(), - message, - } - } - - fn event_id(&self) -> (channel::Id, broadcast::Sequence) { - (self.channel.clone(), self.message.sequence) - } -} - -impl TryFrom<ReplayableEvent> for sse::Event { +impl TryFrom<types::ResumableEvent> for sse::Event { type Error = serde_json::Error; - fn try_from(value: ReplayableEvent) -> Result<Self, Self::Error> { - let ReplayableEvent(id, data) = value; + fn try_from(value: types::ResumableEvent) -> Result<Self, Self::Error> { + let types::ResumableEvent(resume_at, data) = value; - let id = serde_json::to_string(&id)?; + let id = serde_json::to_string(&resume_at)?; let data = serde_json::to_string_pretty(&data)?; let event = Self::default().id(id).data(data); diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index 0b08fd6..a6e2275 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -1,70 +1,40 @@ use axum::extract::State; -use axum_extra::extract::Query; use futures::{ future, stream::{self, StreamExt as _}, }; use crate::{ - events::{app, routes}, - repo::channel::{self}, + events::{routes, types}, test::fixtures::{self, future::Immediately as _}, }; #[tokio::test] -async fn no_subscriptions() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let subscriber = fixtures::login::create(&app).await; - - // Call the endpoint - - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("empty subscription"); - - // Verify the structure of the response. - - assert!(events.next().immediately().await.is_none()); -} - -#[tokio::test] async fn includes_historical_message() { // Set up the environment let app = fixtures::scratch_app().await; let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to valid channel"); + let routes::Events(events) = routes::events(State(app), subscriber, None) + .await + .expect("subscribe never fails"); // Verify the structure of the response. - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events + .filter(fixtures::filter::messages()) .next() .immediately() .await .expect("delivered stored message"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, event.message); + assert_eq!(message, event); } #[tokio::test] @@ -72,74 +42,28 @@ async fn includes_live_message() { // Set up the environment let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(mut events) = routes::events( - State(app.clone()), - subscribed_at, - subscriber, - None, - Query(query), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(events) = routes::events(State(app.clone()), subscriber, None) + .await + .expect("subscribe never fails"); // Verify the semantics let sender = fixtures::login::create(&app).await; let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events + .filter(fixtures::filter::messages()) .next() .immediately() .await .expect("delivered live message"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, event.message); -} - -#[tokio::test] -async fn excludes_other_channels() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let subscribed_channel = fixtures::channel::create(&app).await; - let unsubscribed_channel = fixtures::channel::create(&app).await; - let sender = fixtures::login::create(&app).await; - let message = - fixtures::message::send(&app, &sender, &subscribed_channel, &fixtures::now()).await; - fixtures::message::send(&app, &sender, &unsubscribed_channel, &fixtures::now()).await; - - // Call the endpoint - - let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [subscribed_channel.id.clone()].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to a valid channel"); - - // Verify the semantics - - let routes::ReplayableEvent(_, event) = events - .next() - .immediately() - .await - .expect("delivered at least one message"); - - assert_eq!(subscribed_channel.id, event.channel); - assert_eq!(message, event.message); + assert_eq!(message, event); } #[tokio::test] @@ -150,15 +74,16 @@ async fn includes_multiple_channels() { let sender = fixtures::login::create(&app).await; let channels = [ - fixtures::channel::create(&app).await, - fixtures::channel::create(&app).await, + fixtures::channel::create(&app, &fixtures::now()).await, + fixtures::channel::create(&app, &fixtures::now()).await, ]; let messages = stream::iter(channels) - .then(|channel| async { - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - (channel, message) + .then(|channel| { + let app = app.clone(); + let sender = sender.clone(); + let channel = channel.clone(); + async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } }) .collect::<Vec<_>>() .await; @@ -166,68 +91,32 @@ async fn includes_multiple_channels() { // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: messages - .iter() - .map(|(channel, _)| &channel.id) - .cloned() - .collect(), - }; - let routes::Events(events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to valid channels"); + let routes::Events(events) = routes::events(State(app), subscriber, None) + .await + .expect("subscribe never fails"); // Verify the structure of the response. let events = events + .filter(fixtures::filter::messages()) .take(messages.len()) .collect::<Vec<_>>() .immediately() .await; - for (channel, message) in messages { - assert!(events.iter().any(|routes::ReplayableEvent(_, event)| { - event.channel == channel.id && event.message == message - })); + for message in &messages { + assert!(events + .iter() + .any(|types::ResumableEvent(_, event)| { event == message })); } } #[tokio::test] -async fn nonexistent_channel() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = channel::Id::generate(); - - // Call the endpoint - - let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.clone()].into(), - }; - let routes::ErrorResponse(error) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect_err("subscribed to nonexistent channel"); - - // Verify the structure of the response. - - fixtures::error::expected!( - error, - app::EventsError::ChannelNotFound(error_channel), - assert_eq!(channel, error_channel) - ); -} - -#[tokio::test] async fn sequential_messages() { // Set up the environment let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app).await; let messages = vec![ @@ -239,31 +128,24 @@ async fn sequential_messages() { // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to a valid channel"); + let routes::Events(events) = routes::events(State(app), subscriber, None) + .await + .expect("subscribe never fails"); // Verify the structure of the response. - let mut events = events.filter(|routes::ReplayableEvent(_, event)| { - future::ready(messages.contains(&event.message)) - }); + let mut events = + events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))); // Verify delivery in order for message in &messages { - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events .next() .immediately() .await .expect("undelivered messages remaining"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, &event.message); + assert_eq!(message, &event); } } @@ -272,7 +154,7 @@ async fn resumes_from() { // Set up the environment let app = fixtures::scratch_app().await; - let channel = fixtures::channel::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; let sender = fixtures::login::create(&app).await; let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; @@ -285,43 +167,29 @@ async fn resumes_from() { // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; let resume_at = { // First subscription - let routes::Events(mut events) = routes::events( - State(app.clone()), - subscribed_at, - subscriber.clone(), - None, - Query(query.clone()), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) + .await + .expect("subscribe never fails"); - let routes::ReplayableEvent(id, event) = - events.next().immediately().await.expect("delivered events"); + let types::ResumableEvent(last_event_id, event) = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered events"); - assert_eq!(channel.id, event.channel); - assert_eq!(initial_message, event.message); + assert_eq!(initial_message, event); - id + last_event_id }; // Resume after disconnect - let reconnect_at = fixtures::now(); - let routes::Events(resumed) = routes::events( - State(app), - reconnect_at, - subscriber, - Some(resume_at.into()), - Query(query), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(resumed) = routes::events(State(app), subscriber, Some(resume_at.into())) + .await + .expect("subscribe never fails"); // Verify the structure of the response. @@ -331,11 +199,10 @@ async fn resumes_from() { .immediately() .await; - for message in later_messages { - assert!(events.iter().any( - |routes::ReplayableEvent(_, event)| event.channel == channel.id - && event.message == message - )); + for message in &later_messages { + assert!(events + .iter() + .any(|types::ResumableEvent(_, event)| event == message)); } } @@ -360,15 +227,12 @@ async fn serial_resume() { let app = fixtures::scratch_app().await; let sender = fixtures::login::create(&app).await; - let channel_a = fixtures::channel::create(&app).await; - let channel_b = fixtures::channel::create(&app).await; + let channel_a = fixtures::channel::create(&app, &fixtures::now()).await; + let channel_b = fixtures::channel::create(&app, &fixtures::now()).await; // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let query = routes::EventsQuery { - channels: [channel_a.id.clone(), channel_b.id.clone()].into(), - }; let resume_at = { let initial_messages = [ @@ -377,30 +241,24 @@ async fn serial_resume() { ]; // First subscription - let subscribed_at = fixtures::now(); - let routes::Events(events) = routes::events( - State(app.clone()), - subscribed_at, - subscriber.clone(), - None, - Query(query.clone()), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(events) = routes::events(State(app.clone()), subscriber.clone(), None) + .await + .expect("subscribe never fails"); let events = events + .filter(fixtures::filter::messages()) .take(initial_messages.len()) .collect::<Vec<_>>() .immediately() .await; - for message in initial_messages { + for message in &initial_messages { assert!(events .iter() - .any(|routes::ReplayableEvent(_, event)| event.message == message)); + .any(|types::ResumableEvent(_, event)| event == message)); } - let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty"); + let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); id.to_owned() }; @@ -416,30 +274,28 @@ async fn serial_resume() { ]; // Second subscription - let resubscribed_at = fixtures::now(); let routes::Events(events) = routes::events( State(app.clone()), - resubscribed_at, subscriber.clone(), Some(resume_at.into()), - Query(query.clone()), ) .await - .expect("subscribed to a valid channel"); + .expect("subscribe never fails"); let events = events + .filter(fixtures::filter::messages()) .take(resume_messages.len()) .collect::<Vec<_>>() .immediately() .await; - for message in resume_messages { + for message in &resume_messages { assert!(events .iter() - .any(|routes::ReplayableEvent(_, event)| event.message == message)); + .any(|types::ResumableEvent(_, event)| event == message)); } - let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty"); + let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); id.to_owned() }; @@ -454,19 +310,17 @@ async fn serial_resume() { fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, ]; - // Second subscription - let resubscribed_at = fixtures::now(); + // Third subscription let routes::Events(events) = routes::events( State(app.clone()), - resubscribed_at, subscriber.clone(), Some(resume_at.into()), - Query(query.clone()), ) .await - .expect("subscribed to a valid channel"); + .expect("subscribe never fails"); let events = events + .filter(fixtures::filter::messages()) .take(final_messages.len()) .collect::<Vec<_>>() .immediately() @@ -474,44 +328,10 @@ async fn serial_resume() { // This set of messages, in particular, _should not_ include any prior // messages from `initial_messages` or `resume_messages`. - for message in final_messages { + for message in &final_messages { assert!(events .iter() - .any(|routes::ReplayableEvent(_, event)| event.message == message)); + .any(|types::ResumableEvent(_, event)| event == message)); } }; } - -#[tokio::test] -async fn removes_expired_messages() { - // Set up the environment - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; - - fixtures::message::send(&app, &sender, &channel, &fixtures::ancient()).await; - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - // Call the endpoint - - let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to valid channel"); - - // Verify the semantics - - let routes::ReplayableEvent(_, event) = events - .next() - .immediately() - .await - .expect("delivered messages"); - - assert_eq!(channel.id, event.channel); - assert_eq!(message, event.message); -} diff --git a/src/events/types.rs b/src/events/types.rs new file mode 100644 index 0000000..d954512 --- /dev/null +++ b/src/events/types.rs @@ -0,0 +1,170 @@ +use std::collections::BTreeMap; + +use crate::{ + clock::DateTime, + repo::{ + channel::{self, Channel}, + login::Login, + message, + }, +}; + +#[derive( + Debug, + Default, + Eq, + Ord, + PartialEq, + PartialOrd, + Clone, + Copy, + serde::Serialize, + serde::Deserialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +impl Sequence { + pub fn next(self) -> Self { + let Self(current) = self; + Self(current + 1) + } +} + +// For the purposes of event replay, a resume point is a vector of resume +// elements. A resume element associates a channel (by ID) with the latest event +// seen in that channel so far. Replaying the event stream can restart at a +// predictable point - hence the name. These values can be serialized and sent +// to the client as JSON dicts, then rehydrated to recover the resume point at a +// later time. +// +// Using a sorted map ensures that there is a canonical representation for +// each resume point. +#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)] +#[serde(transparent)] +pub struct ResumePoint(BTreeMap<channel::Id, Sequence>); + +impl ResumePoint { + pub fn advance<'e>(&mut self, event: impl Into<ResumeElement<'e>>) { + let Self(elements) = self; + let ResumeElement(channel, sequence) = event.into(); + elements.insert(channel.clone(), sequence); + } + + pub fn forget<'e>(&mut self, event: impl Into<ResumeElement<'e>>) { + let Self(elements) = self; + let ResumeElement(channel, _) = event.into(); + elements.remove(channel); + } + + pub fn get(&self, channel: &channel::Id) -> Option<Sequence> { + let Self(elements) = self; + elements.get(channel).copied() + } + + pub fn not_after<'e>(&self, event: impl Into<ResumeElement<'e>>) -> bool { + let Self(elements) = self; + let ResumeElement(channel, sequence) = event.into(); + + elements + .get(channel) + .map_or(true, |resume_at| resume_at < &sequence) + } +} + +pub struct ResumeElement<'i>(&'i channel::Id, Sequence); + +#[derive(Clone, Debug)] +pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent); + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct ChannelEvent { + #[serde(skip)] + pub sequence: Sequence, + pub at: DateTime, + #[serde(flatten)] + pub data: ChannelEventData, +} + +impl ChannelEvent { + pub fn created(channel: Channel) -> Self { + Self { + at: channel.created_at, + sequence: Sequence::default(), + data: CreatedEvent { channel }.into(), + } + } + + pub fn channel_id(&self) -> &channel::Id { + match &self.data { + ChannelEventData::Created(event) => &event.channel.id, + ChannelEventData::Message(event) => &event.channel.id, + ChannelEventData::MessageDeleted(event) => &event.channel.id, + ChannelEventData::Deleted(event) => &event.channel, + } + } +} + +impl<'c> From<&'c ChannelEvent> for ResumeElement<'c> { + fn from(event: &'c ChannelEvent) -> Self { + Self(event.channel_id(), event.sequence) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChannelEventData { + Created(CreatedEvent), + Message(MessageEvent), + MessageDeleted(MessageDeletedEvent), + Deleted(DeletedEvent), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct CreatedEvent { + pub channel: Channel, +} + +impl From<CreatedEvent> for ChannelEventData { + fn from(event: CreatedEvent) -> Self { + Self::Created(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageEvent { + pub channel: Channel, + pub sender: Login, + pub message: message::Message, +} + +impl From<MessageEvent> for ChannelEventData { + fn from(event: MessageEvent) -> Self { + Self::Message(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageDeletedEvent { + pub channel: Channel, + pub message: message::Id, +} + +impl From<MessageDeletedEvent> for ChannelEventData { + fn from(event: MessageDeletedEvent) -> Self { + Self::MessageDeleted(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct DeletedEvent { + pub channel: channel::Id, +} + +impl From<DeletedEvent> for ChannelEventData { + fn from(event: DeletedEvent) -> Self { + Self::Deleted(event) + } +} diff --git a/src/expire.rs b/src/expire.rs new file mode 100644 index 0000000..16006d1 --- /dev/null +++ b/src/expire.rs @@ -0,0 +1,20 @@ +use axum::{ + extract::{Request, State}, + middleware::Next, + response::Response, +}; + +use crate::{app::App, clock::RequestedAt, error::Internal}; + +// Expires messages and channels before each request. +pub async fn middleware( + State(app): State<App>, + RequestedAt(expired_at): RequestedAt, + req: Request, + next: Next, +) -> Result<Response, Internal> { + app.logins().expire(&expired_at).await?; + app.events().expire(&expired_at).await?; + app.channels().expire(&expired_at).await?; + Ok(next.run(req).await) +} @@ -8,6 +8,7 @@ pub mod cli; mod clock; mod error; mod events; +mod expire; mod id; mod login; mod password; diff --git a/src/login/app.rs b/src/login/app.rs index 10609c6..f7fec88 100644 --- a/src/login/app.rs +++ b/src/login/app.rs @@ -1,10 +1,10 @@ use chrono::TimeDelta; use sqlx::sqlite::SqlitePool; -use super::repo::auth::Provider as _; +use super::{extract::IdentitySecret, repo::auth::Provider as _}; use crate::{ clock::DateTime, - password::StoredHash, + password::Password, repo::{ error::NotFound as _, login::{Login, Provider as _}, @@ -24,9 +24,9 @@ impl<'a> Logins<'a> { pub async fn login( &self, name: &str, - password: &str, + password: &Password, login_at: &DateTime, - ) -> Result<String, LoginError> { + ) -> Result<IdentitySecret, LoginError> { let mut tx = self.db.begin().await?; let login = if let Some((login, stored_hash)) = tx.auth().for_name(name).await? { @@ -38,7 +38,7 @@ impl<'a> Logins<'a> { return Err(LoginError::Rejected); } } else { - let password_hash = StoredHash::new(password)?; + let password_hash = password.hash()?; tx.logins().create(name, &password_hash).await? }; @@ -49,8 +49,8 @@ impl<'a> Logins<'a> { } #[cfg(test)] - pub async fn create(&self, name: &str, password: &str) -> Result<Login, CreateError> { - let password_hash = StoredHash::new(password)?; + pub async fn create(&self, name: &str, password: &Password) -> Result<Login, CreateError> { + let password_hash = password.hash()?; let mut tx = self.db.begin().await?; let login = tx.logins().create(name, &password_hash).await?; @@ -59,12 +59,12 @@ impl<'a> Logins<'a> { Ok(login) } - pub async fn validate(&self, secret: &str, used_at: &DateTime) -> Result<Login, ValidateError> { - // Somewhat arbitrarily, expire after 7 days. - let expire_at = used_at.to_owned() - TimeDelta::days(7); - + pub async fn validate( + &self, + secret: &IdentitySecret, + used_at: &DateTime, + ) -> Result<Login, ValidateError> { let mut tx = self.db.begin().await?; - tx.tokens().expire(&expire_at).await?; let login = tx .tokens() .validate(secret, used_at) @@ -75,7 +75,18 @@ impl<'a> Logins<'a> { Ok(login) } - pub async fn logout(&self, secret: &str) -> Result<(), ValidateError> { + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 7 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(7); + + let mut tx = self.db.begin().await?; + tx.tokens().expire(&expire_at).await?; + tx.commit().await?; + + Ok(()) + } + + pub async fn logout(&self, secret: &IdentitySecret) -> Result<(), ValidateError> { let mut tx = self.db.begin().await?; tx.tokens() .revoke(secret) diff --git a/src/login/extract.rs b/src/login/extract.rs index 5ef454c..3b31d4c 100644 --- a/src/login/extract.rs +++ b/src/login/extract.rs @@ -1,3 +1,5 @@ +use std::fmt; + use axum::{ extract::FromRequestParts, http::request::Parts, @@ -7,11 +9,22 @@ use axum_extra::extract::cookie::{Cookie, CookieJar}; // The usage pattern here - receive the extractor as an argument, return it in // the response - is heavily modelled after CookieJar's own intended usage. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct IdentityToken { cookies: CookieJar, } +impl fmt::Debug for IdentityToken { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IdentityToken") + .field( + "identity", + &self.cookies.get(IDENTITY_COOKIE).map(|_| "********"), + ) + .finish() + } +} + impl IdentityToken { // Creates a new, unpopulated identity token store. #[cfg(test)] @@ -26,14 +39,18 @@ impl IdentityToken { // return [None]. If the identity has previously been [set], then this // will return that secret, regardless of what the request originally // included. - pub fn secret(&self) -> Option<&str> { - self.cookies.get(IDENTITY_COOKIE).map(Cookie::value) + pub fn secret(&self) -> Option<IdentitySecret> { + self.cookies + .get(IDENTITY_COOKIE) + .map(Cookie::value) + .map(IdentitySecret::from) } // Positively set the identity secret, and ensure that it will be sent // back to the client when this extractor is included in a response. - pub fn set(self, secret: &str) -> Self { - let identity_cookie = Cookie::build((IDENTITY_COOKIE, String::from(secret))) + pub fn set(self, secret: impl Into<IdentitySecret>) -> Self { + let IdentitySecret(secret) = secret.into(); + let identity_cookie = Cookie::build((IDENTITY_COOKIE, secret)) .http_only(true) .path("/api/") .permanent() @@ -76,3 +93,22 @@ impl IntoResponseParts for IdentityToken { cookies.into_response_parts(res) } } + +#[derive(sqlx::Type)] +#[sqlx(transparent)] +pub struct IdentitySecret(String); + +impl fmt::Debug for IdentitySecret { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("IdentityToken").field(&"********").finish() + } +} + +impl<S> From<S> for IdentitySecret +where + S: Into<String>, +{ + fn from(value: S) -> Self { + Self(value.into()) + } +} diff --git a/src/login/routes.rs b/src/login/routes.rs index 31a68d0..4664063 100644 --- a/src/login/routes.rs +++ b/src/login/routes.rs @@ -6,7 +6,9 @@ use axum::{ Router, }; -use crate::{app::App, clock::RequestedAt, error::Internal, repo::login::Login}; +use crate::{ + app::App, clock::RequestedAt, error::Internal, password::Password, repo::login::Login, +}; use super::{app, extract::IdentityToken}; @@ -38,7 +40,7 @@ impl IntoResponse for Boot { #[derive(serde::Deserialize)] struct LoginRequest { name: String, - password: String, + password: Password, } async fn on_login( @@ -52,7 +54,7 @@ async fn on_login( .login(&request.name, &request.password, &now) .await .map_err(LoginError)?; - let identity = identity.set(&token); + let identity = identity.set(token); Ok((identity, StatusCode::NO_CONTENT)) } @@ -82,7 +84,7 @@ async fn on_logout( Json(LogoutRequest {}): Json<LogoutRequest>, ) -> Result<(IdentityToken, StatusCode), LogoutError> { if let Some(secret) = identity.secret() { - app.logins().logout(secret).await.map_err(LogoutError)?; + app.logins().logout(&secret).await.map_err(LogoutError)?; } let identity = identity.clear(); diff --git a/src/login/routes/test/login.rs b/src/login/routes/test/login.rs index d92c01b..10c17d6 100644 --- a/src/login/routes/test/login.rs +++ b/src/login/routes/test/login.rs @@ -38,7 +38,7 @@ async fn new_identity() { let validated_at = fixtures::now(); let validated = app .logins() - .validate(secret, &validated_at) + .validate(&secret, &validated_at) .await .expect("identity secret is valid"); @@ -75,7 +75,7 @@ async fn existing_identity() { let validated_at = fixtures::now(); let validated_login = app .logins() - .validate(secret, &validated_at) + .validate(&secret, &validated_at) .await .expect("identity secret is valid"); @@ -122,14 +122,20 @@ async fn token_expires() { let (identity, _) = routes::on_login(State(app.clone()), logged_in_at, identity, Json(request)) .await .expect("logged in with valid credentials"); - let token = identity.secret().expect("logged in with valid credentials"); + let secret = identity.secret().expect("logged in with valid credentials"); // Verify the semantics + let expired_at = fixtures::now(); + app.logins() + .expire(&expired_at) + .await + .expect("expiring tokens never fails"); + let verified_at = fixtures::now(); let error = app .logins() - .validate(token, &verified_at) + .validate(&secret, &verified_at) .await .expect_err("validating an expired token"); diff --git a/src/login/routes/test/logout.rs b/src/login/routes/test/logout.rs index 4c09a73..05594be 100644 --- a/src/login/routes/test/logout.rs +++ b/src/login/routes/test/logout.rs @@ -37,7 +37,7 @@ async fn successful() { let error = app .logins() - .validate(secret, &now) + .validate(&secret, &now) .await .expect_err("secret is invalid"); match error { diff --git a/src/password.rs b/src/password.rs index b14f728..da3930f 100644 --- a/src/password.rs +++ b/src/password.rs @@ -1,3 +1,5 @@ +use std::fmt; + use argon2::Argon2; use password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; use rand_core::OsRng; @@ -7,16 +9,7 @@ use rand_core::OsRng; pub struct StoredHash(String); impl StoredHash { - pub fn new(password: &str) -> Result<Self, password_hash::Error> { - let salt = SaltString::generate(&mut OsRng); - let argon2 = Argon2::default(); - let hash = argon2 - .hash_password(password.as_bytes(), &salt)? - .to_string(); - Ok(Self(hash)) - } - - pub fn verify(&self, password: &str) -> Result<bool, password_hash::Error> { + pub fn verify(&self, password: &Password) -> Result<bool, password_hash::Error> { let hash = PasswordHash::new(&self.0)?; match Argon2::default().verify_password(password.as_bytes(), &hash) { @@ -29,3 +22,37 @@ impl StoredHash { } } } + +#[derive(serde::Deserialize)] +#[serde(transparent)] +pub struct Password(String); + +impl Password { + pub fn hash(&self) -> Result<StoredHash, password_hash::Error> { + let Self(password) = self; + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let hash = argon2 + .hash_password(password.as_bytes(), &salt)? + .to_string(); + Ok(StoredHash(hash)) + } + + fn as_bytes(&self) -> &[u8] { + let Self(value) = self; + value.as_bytes() + } +} + +impl fmt::Debug for Password { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Password").field(&"********").finish() + } +} + +#[cfg(test)] +impl From<String> for Password { + fn from(password: String) -> Self { + Self(password) + } +} diff --git a/src/repo/channel.rs b/src/repo/channel.rs index 0186413..3c7468f 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -2,7 +2,11 @@ use std::fmt; use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; -use crate::id::Id as BaseId; +use crate::{ + clock::DateTime, + events::types::{self, Sequence}, + id::Id as BaseId, +}; pub trait Provider { fn channels(&mut self) -> Channels; @@ -16,26 +20,38 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); -#[derive(Debug, Eq, PartialEq, serde::Serialize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Channel { pub id: Id, pub name: String, + #[serde(skip)] + pub created_at: DateTime, } impl<'c> Channels<'c> { - pub async fn create(&mut self, name: &str) -> Result<Channel, sqlx::Error> { + pub async fn create( + &mut self, + name: &str, + created_at: &DateTime, + ) -> Result<Channel, sqlx::Error> { let id = Id::generate(); + let sequence = Sequence::default(); let channel = sqlx::query_as!( Channel, r#" insert - into channel (id, name) - values ($1, $2) - returning id as "id: Id", name + into channel (id, name, created_at, last_sequence) + values ($1, $2, $3, $4) + returning + id as "id: Id", + name, + created_at as "created_at: DateTime" "#, id, name, + created_at, + sequence, ) .fetch_one(&mut *self.0) .await?; @@ -47,7 +63,10 @@ impl<'c> Channels<'c> { let channel = sqlx::query_as!( Channel, r#" - select id as "id: Id", name + select + id as "id: Id", + name, + created_at as "created_at: DateTime" from channel where id = $1 "#, @@ -64,8 +83,9 @@ impl<'c> Channels<'c> { Channel, r#" select - channel.id as "id: Id", - channel.name + id as "id: Id", + name, + created_at as "created_at: DateTime" from channel order by channel.name "#, @@ -75,6 +95,52 @@ impl<'c> Channels<'c> { Ok(channels) } + + pub async fn delete_expired( + &mut self, + channel: &Channel, + sequence: Sequence, + deleted_at: &DateTime, + ) -> Result<types::ChannelEvent, sqlx::Error> { + let channel = channel.id.clone(); + sqlx::query_scalar!( + r#" + delete from channel + where id = $1 + returning 1 as "row: i64" + "#, + channel, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence, + at: *deleted_at, + data: types::DeletedEvent { channel }.into(), + }) + } + + pub async fn expired(&mut self, expired_at: &DateTime) -> Result<Vec<Channel>, sqlx::Error> { + let channels = sqlx::query_as!( + Channel, + r#" + select + channel.id as "id: Id", + channel.name, + channel.created_at as "created_at: DateTime" + from channel + left join message + where created_at < $1 + and message.id is null + "#, + expired_at, + ) + .fetch_all(&mut *self.0) + .await?; + + Ok(channels) + } } // Stable identifier for a [Channel]. Prefixed with `C`. diff --git a/src/repo/login/extract.rs b/src/repo/login/extract.rs index e5f96d0..c127078 100644 --- a/src/repo/login/extract.rs +++ b/src/repo/login/extract.rs @@ -32,7 +32,7 @@ impl FromRequestParts<App> for Login { let secret = identity_token.secret().ok_or(LoginError::Unauthorized)?; let app = State::<App>::from_request_parts(parts, state).await?; - match app.logins().validate(secret, &used_at).await { + match app.logins().validate(&secret, &used_at).await { Ok(login) => Ok(login), Err(ValidateError::InvalidToken) => Err(LoginError::Unauthorized), Err(other) => Err(other.into()), diff --git a/src/repo/message.rs b/src/repo/message.rs index 385b103..a1f73d5 100644 --- a/src/repo/message.rs +++ b/src/repo/message.rs @@ -25,3 +25,9 @@ impl fmt::Display for Id { self.0.fmt(f) } } + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct Message { + pub id: Id, + pub body: String, +} diff --git a/src/repo/token.rs b/src/repo/token.rs index 8276bea..15eef48 100644 --- a/src/repo/token.rs +++ b/src/repo/token.rs @@ -2,7 +2,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use uuid::Uuid; use super::login::{self, Login}; -use crate::clock::DateTime; +use crate::{clock::DateTime, login::extract::IdentitySecret}; pub trait Provider { fn tokens(&mut self) -> Tokens; @@ -23,7 +23,7 @@ impl<'c> Tokens<'c> { &mut self, login: &Login, issued_at: &DateTime, - ) -> Result<String, sqlx::Error> { + ) -> Result<IdentitySecret, sqlx::Error> { let secret = Uuid::new_v4().to_string(); let secret = sqlx::query_scalar!( @@ -31,7 +31,7 @@ impl<'c> Tokens<'c> { insert into token (secret, login, issued_at, last_used_at) values ($1, $2, $3, $3) - returning secret as "secret!" + returning secret as "secret!: IdentitySecret" "#, secret, login.id, @@ -44,7 +44,7 @@ impl<'c> Tokens<'c> { } // Revoke a token by its secret. - pub async fn revoke(&mut self, secret: &str) -> Result<(), sqlx::Error> { + pub async fn revoke(&mut self, secret: &IdentitySecret) -> Result<(), sqlx::Error> { sqlx::query!( r#" delete @@ -82,7 +82,7 @@ impl<'c> Tokens<'c> { // timestamp will be set to `used_at`. pub async fn validate( &mut self, - secret: &str, + secret: &IdentitySecret, used_at: &DateTime, ) -> Result<Login, sqlx::Error> { // I would use `update … returning` to do this in one query, but diff --git a/src/test/fixtures/channel.rs b/src/test/fixtures/channel.rs index 0558395..8744470 100644 --- a/src/test/fixtures/channel.rs +++ b/src/test/fixtures/channel.rs @@ -4,12 +4,12 @@ use faker_rand::{ }; use rand; -use crate::{app::App, repo::channel::Channel}; +use crate::{app::App, clock::RequestedAt, repo::channel::Channel}; -pub async fn create(app: &App) -> Channel { +pub async fn create(app: &App, created_at: &RequestedAt) -> Channel { let name = propose(); app.channels() - .create(&name) + .create(&name, created_at) .await .expect("should always succeed if the channel is actually new") } diff --git a/src/test/fixtures/filter.rs b/src/test/fixtures/filter.rs new file mode 100644 index 0000000..fbebced --- /dev/null +++ b/src/test/fixtures/filter.rs @@ -0,0 +1,15 @@ +use futures::future; + +use crate::events::types; + +pub fn messages() -> impl FnMut(&types::ResumableEvent) -> future::Ready<bool> { + |types::ResumableEvent(_, event)| { + future::ready(matches!(event.data, types::ChannelEventData::Message(_))) + } +} + +pub fn created() -> impl FnMut(&types::ResumableEvent) -> future::Ready<bool> { + |types::ResumableEvent(_, event)| { + future::ready(matches!(event.data, types::ChannelEventData::Created(_))) + } +} diff --git a/src/test/fixtures/identity.rs b/src/test/fixtures/identity.rs index 16463aa..69b5f4c 100644 --- a/src/test/fixtures/identity.rs +++ b/src/test/fixtures/identity.rs @@ -1,12 +1,17 @@ use uuid::Uuid; -use crate::{app::App, clock::RequestedAt, login::extract::IdentityToken}; +use crate::{ + app::App, + clock::RequestedAt, + login::extract::{IdentitySecret, IdentityToken}, + password::Password, +}; pub fn not_logged_in() -> IdentityToken { IdentityToken::new() } -pub async fn logged_in(app: &App, login: &(String, String), now: &RequestedAt) -> IdentityToken { +pub async fn logged_in(app: &App, login: &(String, Password), now: &RequestedAt) -> IdentityToken { let (name, password) = login; let token = app .logins() @@ -14,14 +19,14 @@ pub async fn logged_in(app: &App, login: &(String, String), now: &RequestedAt) - .await .expect("should succeed given known-valid credentials"); - IdentityToken::new().set(&token) + IdentityToken::new().set(token) } -pub fn secret(identity: &IdentityToken) -> &str { +pub fn secret(identity: &IdentityToken) -> IdentitySecret { identity.secret().expect("identity contained a secret") } pub fn fictitious() -> IdentityToken { let token = Uuid::new_v4().to_string(); - IdentityToken::new().set(&token) + IdentityToken::new().set(token) } diff --git a/src/test/fixtures/login.rs b/src/test/fixtures/login.rs index f1e4b15..d6a321b 100644 --- a/src/test/fixtures/login.rs +++ b/src/test/fixtures/login.rs @@ -3,10 +3,11 @@ use uuid::Uuid; use crate::{ app::App, + password::Password, repo::login::{self, Login}, }; -pub async fn create_with_password(app: &App) -> (String, String) { +pub async fn create_with_password(app: &App) -> (String, Password) { let (name, password) = propose(); app.logins() .create(&name, &password) @@ -31,7 +32,7 @@ pub fn fictitious() -> Login { } } -pub fn propose() -> (String, String) { +pub fn propose() -> (String, Password) { (name(), propose_password()) } @@ -39,6 +40,6 @@ fn name() -> String { rand::random::<internet::Username>().to_string() } -pub fn propose_password() -> String { - Uuid::new_v4().to_string() +pub fn propose_password() -> Password { + Uuid::new_v4().to_string().into() } diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs index 33feeae..bfca8cd 100644 --- a/src/test/fixtures/message.rs +++ b/src/test/fixtures/message.rs @@ -3,7 +3,7 @@ use faker_rand::lorem::Paragraphs; use crate::{ app::App, clock::RequestedAt, - events::repo::broadcast, + events::types, repo::{channel::Channel, login::Login}, }; @@ -12,7 +12,7 @@ pub async fn send( login: &Login, channel: &Channel, sent_at: &RequestedAt, -) -> broadcast::Message { +) -> types::ChannelEvent { let body = propose(); app.events() diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index a42dba5..d1dd0c3 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -3,6 +3,7 @@ use chrono::{TimeDelta, Utc}; use crate::{app::App, clock::RequestedAt, repo::pool}; pub mod channel; +pub mod filter; pub mod future; pub mod identity; pub mod login; @@ -13,8 +14,6 @@ pub async fn scratch_app() -> App { .await .expect("setting up in-memory sqlite database"); App::from(pool) - .await - .expect("creating an app from a fresh, in-memory database") } pub fn now() -> RequestedAt { |
