diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/app.rs | 8 | ||||
| -rw-r--r-- | src/channel/app.rs | 11 | ||||
| -rw-r--r-- | src/channel/routes/test/on_send.rs | 89 | ||||
| -rw-r--r-- | src/cli.rs | 2 | ||||
| -rw-r--r-- | src/events/app.rs | 93 | ||||
| -rw-r--r-- | src/events/broadcaster.rs | 77 | ||||
| -rw-r--r-- | src/events/mod.rs | 1 | ||||
| -rw-r--r-- | src/events/repo/message.rs (renamed from src/events/repo/broadcast.rs) | 83 | ||||
| -rw-r--r-- | src/events/repo/mod.rs | 2 | ||||
| -rw-r--r-- | src/events/routes.rs | 124 | ||||
| -rw-r--r-- | src/events/routes/test.rs | 276 | ||||
| -rw-r--r-- | src/events/types.rs | 99 | ||||
| -rw-r--r-- | src/repo/channel.rs | 2 | ||||
| -rw-r--r-- | src/test/fixtures/message.rs | 4 | ||||
| -rw-r--r-- | src/test/fixtures/mod.rs | 2 |
15 files changed, 309 insertions, 564 deletions
@@ -13,9 +13,9 @@ pub struct App { } impl App { - pub async fn from(db: SqlitePool) -> Result<Self, sqlx::Error> { - let broadcaster = Broadcaster::from_database(&db).await?; - Ok(Self { db, broadcaster }) + pub fn from(db: SqlitePool) -> Self { + let broadcaster = Broadcaster::default(); + Self { db, broadcaster } } } @@ -29,6 +29,6 @@ impl App { } pub const fn channels(&self) -> Channels { - Channels::new(&self.db, &self.broadcaster) + Channels::new(&self.db) } } diff --git a/src/channel/app.rs b/src/channel/app.rs index 793fa35..6bad158 100644 --- a/src/channel/app.rs +++ b/src/channel/app.rs @@ -1,18 +1,14 @@ use sqlx::sqlite::SqlitePool; -use crate::{ - events::broadcaster::Broadcaster, - repo::channel::{Channel, Provider as _}, -}; +use crate::repo::channel::{Channel, Provider as _}; pub struct Channels<'a> { db: &'a SqlitePool, - broadcaster: &'a Broadcaster, } impl<'a> Channels<'a> { - pub const fn new(db: &'a SqlitePool, broadcaster: &'a Broadcaster) -> Self { - Self { db, broadcaster } + pub const fn new(db: &'a SqlitePool) -> Self { + Self { db } } pub async fn create(&self, name: &str) -> Result<Channel, CreateError> { @@ -22,7 +18,6 @@ impl<'a> Channels<'a> { .create(name) .await .map_err(|err| CreateError::from_duplicate_name(err, name))?; - self.broadcaster.register_channel(&channel.id); tx.commit().await?; Ok(channel) diff --git a/src/channel/routes/test/on_send.rs b/src/channel/routes/test/on_send.rs index 93a5480..5d87bdc 100644 --- a/src/channel/routes/test/on_send.rs +++ b/src/channel/routes/test/on_send.rs @@ -1,65 +1,14 @@ -use axum::{ - extract::{Json, Path, State}, - http::StatusCode, -}; +use axum::extract::{Json, Path, State}; use futures::stream::StreamExt; use crate::{ channel::routes, - events::app, + events::{app, types}, repo::channel, test::fixtures::{self, future::Immediately as _}, }; #[tokio::test] -async fn channel_exists() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let sender = fixtures::login::create(&app).await; - let channel = fixtures::channel::create(&app).await; - - // Call the endpoint - - let sent_at = fixtures::now(); - let request = routes::SendRequest { - message: fixtures::message::propose(), - }; - let status = routes::on_send( - State(app.clone()), - Path(channel.id.clone()), - sent_at.clone(), - sender.clone(), - Json(request.clone()), - ) - .await - .expect("sending to a valid channel"); - - // Verify the structure of the response - - assert_eq!(StatusCode::ACCEPTED, status); - - // Verify the semantics - - let subscribed_at = fixtures::now(); - let mut events = app - .events() - .subscribe(&channel.id, &subscribed_at, None) - .await - .expect("subscribing to a valid channel"); - - let event = events - .next() - .immediately() - .await - .expect("event received by subscribers"); - - assert_eq!(request.message, event.body); - assert_eq!(sender, event.sender); - assert_eq!(*sent_at, event.sent_at); -} - -#[tokio::test] async fn messages_in_order() { // Set up the environment @@ -70,21 +19,15 @@ async fn messages_in_order() { // Call the endpoint (twice) let requests = vec![ - ( - fixtures::now(), - routes::SendRequest { - message: fixtures::message::propose(), - }, - ), - ( - fixtures::now(), - routes::SendRequest { - message: fixtures::message::propose(), - }, - ), + (fixtures::now(), fixtures::message::propose()), + (fixtures::now(), fixtures::message::propose()), ]; - for (sent_at, request) in &requests { + for (sent_at, message) in &requests { + let request = routes::SendRequest { + message: message.clone(), + }; + routes::on_send( State(app.clone()), Path(channel.id.clone()), @@ -101,17 +44,21 @@ async fn messages_in_order() { let subscribed_at = fixtures::now(); let events = app .events() - .subscribe(&channel.id, &subscribed_at, None) + .subscribe(&subscribed_at, types::ResumePoint::default()) .await .expect("subscribing to a valid channel") .take(requests.len()); let events = events.collect::<Vec<_>>().immediately().await; - for ((sent_at, request), event) in requests.into_iter().zip(events) { - assert_eq!(request.message, event.body); - assert_eq!(sender, event.sender); - assert_eq!(*sent_at, event.sent_at); + for ((sent_at, message), types::ResumableEvent(_, event)) in requests.into_iter().zip(events) { + assert_eq!(*sent_at, event.at); + assert!(matches!( + event.data, + types::ChannelEventData::Message(event_message) + if event_message.sender == sender + && event_message.body == message + )); } } @@ -70,7 +70,7 @@ impl Args { pub async fn run(self) -> Result<(), Error> { let pool = self.pool().await?; - let app = App::from(pool).await?; + let app = App::from(pool); let app = routers() .route_layer(middleware::from_fn(clock::middleware)) .with_state(app); diff --git a/src/events/app.rs b/src/events/app.rs index 7229551..043a29b 100644 --- a/src/events/app.rs +++ b/src/events/app.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use chrono::TimeDelta; use futures::{ future, @@ -8,7 +10,8 @@ use sqlx::sqlite::SqlitePool; use super::{ broadcaster::Broadcaster, - repo::broadcast::{self, Provider as _}, + repo::message::Provider as _, + types::{self, ResumePoint}, }; use crate::{ clock::DateTime, @@ -35,64 +38,56 @@ impl<'a> Events<'a> { channel: &channel::Id, body: &str, sent_at: &DateTime, - ) -> Result<broadcast::Message, EventsError> { + ) -> Result<types::ChannelEvent, EventsError> { let mut tx = self.db.begin().await?; let channel = tx .channels() .by_id(channel) .await .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; - let message = tx - .broadcast() + let event = tx + .message_events() .create(login, &channel, body, sent_at) .await?; tx.commit().await?; - self.broadcaster.broadcast(&channel.id, &message); - Ok(message) + self.broadcaster.broadcast(&event); + Ok(event) } pub async fn subscribe( &self, - channel: &channel::Id, subscribed_at: &DateTime, - resume_at: Option<broadcast::Sequence>, - ) -> Result<impl Stream<Item = broadcast::Message> + std::fmt::Debug, EventsError> { + resume_at: ResumePoint, + ) -> Result<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug, sqlx::Error> { // Somewhat arbitrarily, expire after 90 days. let expire_at = subscribed_at.to_owned() - TimeDelta::days(90); let mut tx = self.db.begin().await?; - let channel = tx - .channels() - .by_id(channel) - .await - .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let channels = tx.channels().all().await?; // Subscribe before retrieving, to catch messages broadcast while we're // querying the DB. We'll prune out duplicates later. - let live_messages = self.broadcaster.subscribe(&channel.id); + let live_messages = self.broadcaster.subscribe(); - tx.broadcast().expire(&expire_at).await?; - let stored_messages = tx.broadcast().replay(&channel, resume_at).await?; - tx.commit().await?; + tx.message_events().expire(&expire_at).await?; - let resume_broadcast_at = stored_messages - .last() - .map(|message| message.sequence) - .or(resume_at); + let mut replays = BTreeMap::new(); + let mut resume_live_at = resume_at.clone(); + for channel in channels { + let replay = tx + .message_events() + .replay(&channel, resume_at.get(&channel.id)) + .await?; - // This should always be the case, up to integer rollover, primarily - // because every message in stored_messages has a sequence not less - // than `resume_at`, or `resume_at` is None. We use the last message - // (if any) to decide when to resume the `live_messages` stream. - // - // It probably simplifies to assert!(resume_at <= resume_broadcast_at), but - // this form captures more of the reasoning. - assert!( - (resume_at.is_none() && resume_broadcast_at.is_none()) - || (stored_messages.is_empty() && resume_at == resume_broadcast_at) - || resume_at < resume_broadcast_at - ); + if let Some(last) = replay.last() { + resume_live_at.advance(&channel.id, last.sequence); + } + + replays.insert(channel.id.clone(), replay); + } + + let replay = stream::select_all(replays.into_values().map(stream::iter)); // no skip_expired or resume transforms for stored_messages, as it's // constructed not to contain messages meeting either criterion. @@ -100,7 +95,6 @@ impl<'a> Events<'a> { // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; // * resume is redundant with the resume_at argument to // `tx.broadcasts().replay(…)`. - let stored_messages = stream::iter(stored_messages); let live_messages = live_messages // Sure, it's temporally improbable that we'll ever skip a message // that's 90 days old, but there's no reason not to be thorough. @@ -108,26 +102,31 @@ impl<'a> Events<'a> { // Filtering on the broadcast resume point filters out messages // before resume_at, and filters out messages duplicated from // stored_messages. - .filter(Self::resume(resume_broadcast_at)); + .filter(Self::resume(resume_live_at)); - Ok(stored_messages.chain(live_messages)) + Ok(replay + .chain(live_messages) + .scan(resume_at, |resume_point, event| { + let channel = &event.channel.id; + let sequence = event.sequence; + resume_point.advance(channel, sequence); + + let event = types::ResumableEvent(resume_point.clone(), event); + + future::ready(Some(event)) + })) } fn resume( - resume_at: Option<broadcast::Sequence>, - ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> { - move |msg| { - future::ready(match resume_at { - None => true, - Some(resume_at) => msg.sequence > resume_at, - }) - } + resume_at: ResumePoint, + ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { + move |event| future::ready(resume_at < event.sequence()) } fn skip_expired( expire_at: &DateTime, - ) -> impl for<'m> FnMut(&'m broadcast::Message) -> future::Ready<bool> { + ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { let expire_at = expire_at.to_owned(); - move |msg| future::ready(msg.sent_at > expire_at) + move |event| future::ready(expire_at < event.at) } } diff --git a/src/events/broadcaster.rs b/src/events/broadcaster.rs index dcaba91..9697c0a 100644 --- a/src/events/broadcaster.rs +++ b/src/events/broadcaster.rs @@ -1,63 +1,35 @@ -use std::collections::{hash_map::Entry, HashMap}; -use std::sync::{Arc, Mutex, MutexGuard}; +use std::sync::{Arc, Mutex}; use futures::{future, stream::StreamExt as _, Stream}; -use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast::{channel, Sender}; use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}; -use crate::{ - events::repo::broadcast, - repo::channel::{self, Provider as _}, -}; +use crate::events::types; -// Clones will share the same senders collection. +// Clones will share the same sender. #[derive(Clone)] pub struct Broadcaster { // The use of std::sync::Mutex, and not tokio::sync::Mutex, follows Tokio's // own advice: <https://tokio.rs/tokio/tutorial/shared-state>. Methods that // lock it must be sync. - senders: Arc<Mutex<HashMap<channel::Id, Sender<broadcast::Message>>>>, + senders: Arc<Mutex<Sender<types::ChannelEvent>>>, } -impl Broadcaster { - pub async fn from_database(db: &SqlitePool) -> Result<Self, sqlx::Error> { - let mut tx = db.begin().await?; - let channels = tx.channels().all().await?; - tx.commit().await?; - - let channels = channels.iter().map(|c| &c.id); - let broadcaster = Self::new(channels); - Ok(broadcaster) - } - - fn new<'i>(channels: impl IntoIterator<Item = &'i channel::Id>) -> Self { - let senders: HashMap<_, _> = channels - .into_iter() - .cloned() - .map(|id| (id, Self::make_sender())) - .collect(); +impl Default for Broadcaster { + fn default() -> Self { + let sender = Self::make_sender(); Self { - senders: Arc::new(Mutex::new(senders)), - } - } - - // panic: if ``channel`` is already registered. - pub fn register_channel(&self, channel: &channel::Id) { - match self.senders().entry(channel.clone()) { - // This ever happening indicates a serious logic error. - Entry::Occupied(_) => panic!("duplicate channel registration for channel {channel}"), - Entry::Vacant(entry) => { - entry.insert(Self::make_sender()); - } + senders: Arc::new(Mutex::new(sender)), } } +} - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - pub fn broadcast(&self, channel: &channel::Id, message: &broadcast::Message) { - let tx = self.sender(channel); +impl Broadcaster { + // panic: if ``message.channel.id`` has not been previously registered, + // and was not part of the initial set of channels. + pub fn broadcast(&self, message: &types::ChannelEvent) { + let tx = self.sender(); // Per the Tokio docs, the returned error is only used to indicate that // there are no receivers. In this use case, that's fine; a lack of @@ -71,15 +43,12 @@ impl Broadcaster { // panic: if ``channel`` has not been previously registered, and was not // part of the initial set of channels. - pub fn subscribe( - &self, - channel: &channel::Id, - ) -> impl Stream<Item = broadcast::Message> + std::fmt::Debug { - let rx = self.sender(channel).subscribe(); + pub fn subscribe(&self) -> impl Stream<Item = types::ChannelEvent> + std::fmt::Debug { + let rx = self.sender().subscribe(); BroadcastStream::from(rx).scan((), |(), r| { future::ready(match r { - Ok(message) => Some(message), + Ok(event) => Some(event), // Stop the stream here. This will disconnect SSE clients // (see `routes.rs`), who will then resume from // `Last-Event-ID`, allowing them to catch up by reading @@ -92,17 +61,11 @@ impl Broadcaster { }) } - // panic: if ``channel`` has not been previously registered, and was not - // part of the initial set of channels. - fn sender(&self, channel: &channel::Id) -> Sender<broadcast::Message> { - self.senders()[channel].clone() - } - - fn senders(&self) -> MutexGuard<HashMap<channel::Id, Sender<broadcast::Message>>> { - self.senders.lock().unwrap() // propagate panics when mutex is poisoned + fn sender(&self) -> Sender<types::ChannelEvent> { + self.senders.lock().unwrap().clone() } - fn make_sender() -> Sender<broadcast::Message> { + fn make_sender() -> Sender<types::ChannelEvent> { // Queue depth of 16 chosen entirely arbitrarily. Don't read too much // into it. let (tx, _) = channel(16); diff --git a/src/events/mod.rs b/src/events/mod.rs index b9f3f5b..711ae64 100644 --- a/src/events/mod.rs +++ b/src/events/mod.rs @@ -3,5 +3,6 @@ pub mod broadcaster; mod extract; pub mod repo; mod routes; +pub mod types; pub use self::routes::router; diff --git a/src/events/repo/broadcast.rs b/src/events/repo/message.rs index 6914573..b4724ea 100644 --- a/src/events/repo/broadcast.rs +++ b/src/events/repo/message.rs @@ -2,6 +2,7 @@ use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; use crate::{ clock::DateTime, + events::types::{self, Sequence}, repo::{ channel::Channel, login::{self, Login}, @@ -10,35 +11,25 @@ use crate::{ }; pub trait Provider { - fn broadcast(&mut self) -> Broadcast; + fn message_events(&mut self) -> Events; } impl<'c> Provider for Transaction<'c, Sqlite> { - fn broadcast(&mut self) -> Broadcast { - Broadcast(self) + fn message_events(&mut self) -> Events { + Events(self) } } -pub struct Broadcast<'t>(&'t mut SqliteConnection); +pub struct Events<'t>(&'t mut SqliteConnection); -#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] -pub struct Message { - pub id: message::Id, - #[serde(skip)] - pub sequence: Sequence, - pub sender: Login, - pub body: String, - pub sent_at: DateTime, -} - -impl<'c> Broadcast<'c> { +impl<'c> Events<'c> { pub async fn create( &mut self, sender: &Login, channel: &Channel, body: &str, sent_at: &DateTime, - ) -> Result<Message, sqlx::Error> { + ) -> Result<types::ChannelEvent, sqlx::Error> { let sequence = self.next_sequence_for(channel).await?; let id = message::Id::generate(); @@ -62,12 +53,16 @@ impl<'c> Broadcast<'c> { body, sent_at, ) - .map(|row| Message { - id: row.id, + .map(|row| types::ChannelEvent { sequence: row.sequence, - sender: sender.clone(), - body: row.body, - sent_at: row.sent_at, + at: row.sent_at, + channel: channel.clone(), + data: types::MessageEvent { + id: row.id, + sender: sender.clone(), + body: row.body, + } + .into(), }) .fetch_one(&mut *self.0) .await?; @@ -76,7 +71,7 @@ impl<'c> Broadcast<'c> { } async fn next_sequence_for(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> { - let Sequence(current) = sqlx::query_scalar!( + let current = sqlx::query_scalar!( r#" -- `max` never returns null, but sqlx can't detect that select max(sequence) as "sequence!: Sequence" @@ -88,7 +83,7 @@ impl<'c> Broadcast<'c> { .fetch_one(&mut *self.0) .await?; - Ok(Sequence(current + 1)) + Ok(current.next()) } pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { @@ -109,8 +104,8 @@ impl<'c> Broadcast<'c> { &mut self, channel: &Channel, resume_at: Option<Sequence>, - ) -> Result<Vec<Message>, sqlx::Error> { - let messages = sqlx::query!( + ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { + let events = sqlx::query!( r#" select message.id as "id: message::Id", @@ -128,35 +123,23 @@ impl<'c> Broadcast<'c> { channel.id, resume_at, ) - .map(|row| Message { - id: row.id, + .map(|row| types::ChannelEvent { sequence: row.sequence, - sender: Login { - id: row.sender_id, - name: row.sender_name, - }, - body: row.body, - sent_at: row.sent_at, + at: row.sent_at, + channel: channel.clone(), + data: types::MessageEvent { + id: row.id, + sender: login::Login { + id: row.sender_id, + name: row.sender_name, + }, + body: row.body, + } + .into(), }) .fetch_all(&mut *self.0) .await?; - Ok(messages) + Ok(events) } } - -#[derive( - Debug, - Eq, - Ord, - PartialEq, - PartialOrd, - Clone, - Copy, - serde::Serialize, - serde::Deserialize, - sqlx::Type, -)] -#[serde(transparent)] -#[sqlx(transparent)] -pub struct Sequence(i64); diff --git a/src/events/repo/mod.rs b/src/events/repo/mod.rs index 2ed3062..e216a50 100644 --- a/src/events/repo/mod.rs +++ b/src/events/repo/mod.rs @@ -1 +1 @@ -pub mod broadcast; +pub mod message; diff --git a/src/events/routes.rs b/src/events/routes.rs index d901f9b..3f70dcd 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -1,8 +1,5 @@ -use std::collections::{BTreeMap, HashSet}; - use axum::{ extract::State, - http::StatusCode, response::{ sse::{self, Sse}, IntoResponse, Response, @@ -10,87 +7,32 @@ use axum::{ routing::get, Router, }; -use axum_extra::extract::Query; -use futures::{ - future, - stream::{self, Stream, StreamExt as _, TryStreamExt as _}, -}; +use futures::stream::{Stream, StreamExt as _}; -use super::{extract::LastEventId, repo::broadcast}; -use crate::{ - app::App, - clock::RequestedAt, - error::Internal, - events::app::EventsError, - repo::{channel, login::Login}, +use super::{ + extract::LastEventId, + types::{self, ResumePoint}, }; +use crate::{app::App, clock::RequestedAt, error::Internal, repo::login::Login}; #[cfg(test)] mod test; -// For the purposes of event replay, an "event ID" is a vector of per-channel -// sequence numbers. Replay will start with messages whose sequence number in -// its channel is higher than the sequence in the event ID, or if the channel -// is not listed in the event ID, then at the beginning. -// -// Using a sorted map ensures that there is a canonical representation for -// each event ID. -type EventId = BTreeMap<channel::Id, broadcast::Sequence>; - pub fn router() -> Router<App> { Router::new().route("/api/events", get(events)) } -#[derive(Clone, serde::Deserialize)] -struct EventsQuery { - #[serde(default, rename = "channel")] - channels: HashSet<channel::Id>, -} - async fn events( State(app): State<App>, - RequestedAt(now): RequestedAt, + RequestedAt(subscribed_at): RequestedAt, _: Login, // requires auth, but doesn't actually care who you are - last_event_id: Option<LastEventId<EventId>>, - Query(query): Query<EventsQuery>, -) -> Result<Events<impl Stream<Item = ReplayableEvent> + std::fmt::Debug>, ErrorResponse> { + last_event_id: Option<LastEventId<ResumePoint>>, +) -> Result<Events<impl Stream<Item = types::ResumableEvent> + std::fmt::Debug>, Internal> { let resume_at = last_event_id .map(LastEventId::into_inner) .unwrap_or_default(); - let streams = stream::iter(query.channels) - .then(|channel| { - let app = app.clone(); - let resume_at = resume_at.clone(); - async move { - let resume_at = resume_at.get(&channel).copied(); - - let events = app - .events() - .subscribe(&channel, &now, resume_at) - .await? - .map(ChannelEvent::wrap(channel)); - - Ok::<_, EventsError>(events) - } - }) - .try_collect::<Vec<_>>() - .await - // impl From would take more code; this is used once. - .map_err(ErrorResponse)?; - - // We resume counting from the provided last-event-id mapping, rather than - // starting from scratch, so that the events in a resumed stream contain - // the full vector of channel IDs for their event IDs right off the bat, - // even before any events are actually delivered. - let stream = stream::select_all(streams).scan(resume_at, |sequences, event| { - let (channel, sequence) = event.event_id(); - sequences.insert(channel, sequence); - - let event = ReplayableEvent(sequences.clone(), event); - - future::ready(Some(event)) - }); + let stream = app.events().subscribe(&subscribed_at, resume_at).await?; Ok(Events(stream)) } @@ -100,7 +42,7 @@ struct Events<S>(S); impl<S> IntoResponse for Events<S> where - S: Stream<Item = ReplayableEvent> + Send + 'static, + S: Stream<Item = types::ResumableEvent> + Send + 'static, { fn into_response(self) -> Response { let Self(stream) = self; @@ -111,51 +53,13 @@ where } } -#[derive(Debug)] -struct ErrorResponse(EventsError); - -impl IntoResponse for ErrorResponse { - fn into_response(self) -> Response { - let Self(error) = self; - match error { - not_found @ EventsError::ChannelNotFound(_) => { - (StatusCode::NOT_FOUND, not_found.to_string()).into_response() - } - other => Internal::from(other).into_response(), - } - } -} - -#[derive(Debug)] -struct ReplayableEvent(EventId, ChannelEvent); - -#[derive(Debug, serde::Serialize)] -struct ChannelEvent { - channel: channel::Id, - #[serde(flatten)] - message: broadcast::Message, -} - -impl ChannelEvent { - fn wrap(channel: channel::Id) -> impl Fn(broadcast::Message) -> Self { - move |message| Self { - channel: channel.clone(), - message, - } - } - - fn event_id(&self) -> (channel::Id, broadcast::Sequence) { - (self.channel.clone(), self.message.sequence) - } -} - -impl TryFrom<ReplayableEvent> for sse::Event { +impl TryFrom<types::ResumableEvent> for sse::Event { type Error = serde_json::Error; - fn try_from(value: ReplayableEvent) -> Result<Self, Self::Error> { - let ReplayableEvent(id, data) = value; + fn try_from(value: types::ResumableEvent) -> Result<Self, Self::Error> { + let types::ResumableEvent(resume_at, data) = value; - let id = serde_json::to_string(&id)?; + let id = serde_json::to_string(&resume_at)?; let data = serde_json::to_string_pretty(&data)?; let event = Self::default().id(id).data(data); diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index 4412938..f289225 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -1,40 +1,15 @@ use axum::extract::State; -use axum_extra::extract::Query; use futures::{ future, stream::{self, StreamExt as _}, }; use crate::{ - events::{app, routes}, - repo::channel::{self}, + events::{routes, types}, test::fixtures::{self, future::Immediately as _}, }; #[tokio::test] -async fn no_subscriptions() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let subscriber = fixtures::login::create(&app).await; - - // Call the endpoint - - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("empty subscription"); - - // Verify the structure of the response. - - assert!(events.next().immediately().await.is_none()); -} - -#[tokio::test] async fn includes_historical_message() { // Set up the environment @@ -47,24 +22,19 @@ async fn includes_historical_message() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to valid channel"); + let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None) + .await + .expect("subscribe never fails"); // Verify the structure of the response. - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events .next() .immediately() .await .expect("delivered stored message"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, event.message); + assert_eq!(message, event); } #[tokio::test] @@ -78,68 +48,23 @@ async fn includes_live_message() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(mut events) = routes::events( - State(app.clone()), - subscribed_at, - subscriber, - None, - Query(query), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(mut events) = + routes::events(State(app.clone()), subscribed_at, subscriber, None) + .await + .expect("subscribe never fails"); // Verify the semantics let sender = fixtures::login::create(&app).await; let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events .next() .immediately() .await .expect("delivered live message"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, event.message); -} - -#[tokio::test] -async fn excludes_other_channels() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let subscribed_channel = fixtures::channel::create(&app).await; - let unsubscribed_channel = fixtures::channel::create(&app).await; - let sender = fixtures::login::create(&app).await; - let message = - fixtures::message::send(&app, &sender, &subscribed_channel, &fixtures::now()).await; - fixtures::message::send(&app, &sender, &unsubscribed_channel, &fixtures::now()).await; - - // Call the endpoint - - let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [subscribed_channel.id.clone()].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to a valid channel"); - - // Verify the semantics - - let routes::ReplayableEvent(_, event) = events - .next() - .immediately() - .await - .expect("delivered at least one message"); - - assert_eq!(subscribed_channel.id, event.channel); - assert_eq!(message, event.message); + assert_eq!(message, event); } #[tokio::test] @@ -155,10 +80,11 @@ async fn includes_multiple_channels() { ]; let messages = stream::iter(channels) - .then(|channel| async { - let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - - (channel, message) + .then(|channel| { + let app = app.clone(); + let sender = sender.clone(); + let channel = channel.clone(); + async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } }) .collect::<Vec<_>>() .await; @@ -167,17 +93,9 @@ async fn includes_multiple_channels() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: messages - .iter() - .map(|(channel, _)| &channel.id) - .cloned() - .collect(), - }; - let routes::Events(events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to valid channels"); + let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None) + .await + .expect("subscribe never fails"); // Verify the structure of the response. @@ -187,41 +105,14 @@ async fn includes_multiple_channels() { .immediately() .await; - for (channel, message) in messages { - assert!(events.iter().any(|routes::ReplayableEvent(_, event)| { - event.channel == channel.id && event.message == message - })); + for message in &messages { + assert!(events + .iter() + .any(|types::ResumableEvent(_, event)| { event == message })); } } #[tokio::test] -async fn nonexistent_channel() { - // Set up the environment - - let app = fixtures::scratch_app().await; - let channel = channel::Id::generate(); - - // Call the endpoint - - let subscriber = fixtures::login::create(&app).await; - let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.clone()].into(), - }; - let routes::ErrorResponse(error) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect_err("subscribed to nonexistent channel"); - - // Verify the structure of the response. - - assert!(matches!( - error, - app::EventsError::ChannelNotFound(error_channel) if error_channel == channel - )); -} - -#[tokio::test] async fn sequential_messages() { // Set up the environment @@ -239,30 +130,24 @@ async fn sequential_messages() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to a valid channel"); + let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None) + .await + .expect("subscribe never fails"); // Verify the structure of the response. - let mut events = events.filter(|routes::ReplayableEvent(_, event)| { - future::ready(messages.contains(&event.message)) - }); + let mut events = + events.filter(|types::ResumableEvent(_, event)| future::ready(messages.contains(event))); // Verify delivery in order for message in &messages { - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events .next() .immediately() .await .expect("undelivered messages remaining"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, &event.message); + assert_eq!(message, &event); } } @@ -285,42 +170,28 @@ async fn resumes_from() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; let resume_at = { // First subscription - let routes::Events(mut events) = routes::events( - State(app.clone()), - subscribed_at, - subscriber.clone(), - None, - Query(query.clone()), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(mut events) = + routes::events(State(app.clone()), subscribed_at, subscriber.clone(), None) + .await + .expect("subscribe never fails"); - let routes::ReplayableEvent(id, event) = + let types::ResumableEvent(last_event_id, event) = events.next().immediately().await.expect("delivered events"); - assert_eq!(channel.id, event.channel); - assert_eq!(initial_message, event.message); + assert_eq!(initial_message, event); - id + last_event_id }; // Resume after disconnect let reconnect_at = fixtures::now(); - let routes::Events(resumed) = routes::events( - State(app), - reconnect_at, - subscriber, - Some(resume_at.into()), - Query(query), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(resumed) = + routes::events(State(app), reconnect_at, subscriber, Some(resume_at.into())) + .await + .expect("subscribe never fails"); // Verify the structure of the response. @@ -330,11 +201,10 @@ async fn resumes_from() { .immediately() .await; - for message in later_messages { - assert!(events.iter().any( - |routes::ReplayableEvent(_, event)| event.channel == channel.id - && event.message == message - )); + for message in &later_messages { + assert!(events + .iter() + .any(|types::ResumableEvent(_, event)| event == message)); } } @@ -365,9 +235,6 @@ async fn serial_resume() { // Call the endpoint let subscriber = fixtures::login::create(&app).await; - let query = routes::EventsQuery { - channels: [channel_a.id.clone(), channel_b.id.clone()].into(), - }; let resume_at = { let initial_messages = [ @@ -377,15 +244,10 @@ async fn serial_resume() { // First subscription let subscribed_at = fixtures::now(); - let routes::Events(events) = routes::events( - State(app.clone()), - subscribed_at, - subscriber.clone(), - None, - Query(query.clone()), - ) - .await - .expect("subscribed to a valid channel"); + let routes::Events(events) = + routes::events(State(app.clone()), subscribed_at, subscriber.clone(), None) + .await + .expect("subscribe never fails"); let events = events .take(initial_messages.len()) @@ -393,13 +255,13 @@ async fn serial_resume() { .immediately() .await; - for message in initial_messages { + for message in &initial_messages { assert!(events .iter() - .any(|routes::ReplayableEvent(_, event)| event.message == message)); + .any(|types::ResumableEvent(_, event)| event == message)); } - let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty"); + let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); id.to_owned() }; @@ -421,10 +283,9 @@ async fn serial_resume() { resubscribed_at, subscriber.clone(), Some(resume_at.into()), - Query(query.clone()), ) .await - .expect("subscribed to a valid channel"); + .expect("subscribe never fails"); let events = events .take(resume_messages.len()) @@ -432,13 +293,13 @@ async fn serial_resume() { .immediately() .await; - for message in resume_messages { + for message in &resume_messages { assert!(events .iter() - .any(|routes::ReplayableEvent(_, event)| event.message == message)); + .any(|types::ResumableEvent(_, event)| event == message)); } - let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty"); + let types::ResumableEvent(id, _) = events.last().expect("this vec is non-empty"); id.to_owned() }; @@ -460,10 +321,9 @@ async fn serial_resume() { resubscribed_at, subscriber.clone(), Some(resume_at.into()), - Query(query.clone()), ) .await - .expect("subscribed to a valid channel"); + .expect("subscribe never fails"); let events = events .take(final_messages.len()) @@ -473,10 +333,10 @@ async fn serial_resume() { // This set of messages, in particular, _should not_ include any prior // messages from `initial_messages` or `resume_messages`. - for message in final_messages { + for message in &final_messages { assert!(events .iter() - .any(|routes::ReplayableEvent(_, event)| event.message == message)); + .any(|types::ResumableEvent(_, event)| event == message)); } }; } @@ -495,22 +355,18 @@ async fn removes_expired_messages() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { - channels: [channel.id.clone()].into(), - }; - let routes::Events(mut events) = - routes::events(State(app), subscribed_at, subscriber, None, Query(query)) - .await - .expect("subscribed to valid channel"); + + let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None) + .await + .expect("subscribe never fails"); // Verify the semantics - let routes::ReplayableEvent(_, event) = events + let types::ResumableEvent(_, event) = events .next() .immediately() .await .expect("delivered messages"); - assert_eq!(channel.id, event.channel); - assert_eq!(message, event.message); + assert_eq!(message, event); } diff --git a/src/events/types.rs b/src/events/types.rs new file mode 100644 index 0000000..6747afc --- /dev/null +++ b/src/events/types.rs @@ -0,0 +1,99 @@ +use std::collections::BTreeMap; + +use crate::{ + clock::DateTime, + repo::{ + channel::{self, Channel}, + login::Login, + message, + }, +}; + +#[derive( + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + Clone, + Copy, + serde::Serialize, + serde::Deserialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +impl Sequence { + pub fn next(self) -> Self { + let Self(current) = self; + Self(current + 1) + } +} + +// For the purposes of event replay, an "event ID" is a vector of per-channel +// sequence numbers. Replay will start with messages whose sequence number in +// its channel is higher than the sequence in the event ID, or if the channel +// is not listed in the event ID, then at the beginning. +// +// Using a sorted map ensures that there is a canonical representation for +// each event ID. +#[derive(Clone, Debug, Default, PartialEq, PartialOrd, serde::Deserialize, serde::Serialize)] +#[serde(transparent)] +pub struct ResumePoint(BTreeMap<channel::Id, Sequence>); + +impl ResumePoint { + pub fn singleton(channel: &channel::Id, sequence: Sequence) -> Self { + let mut vector = Self::default(); + vector.advance(channel, sequence); + vector + } + + pub fn advance(&mut self, channel: &channel::Id, sequence: Sequence) { + let Self(elements) = self; + elements.insert(channel.clone(), sequence); + } + + pub fn get(&self, channel: &channel::Id) -> Option<Sequence> { + let Self(elements) = self; + elements.get(channel).copied() + } +} +#[derive(Clone, Debug)] +pub struct ResumableEvent(pub ResumePoint, pub ChannelEvent); + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct ChannelEvent { + #[serde(skip)] + pub sequence: Sequence, + pub at: DateTime, + pub channel: Channel, + #[serde(flatten)] + pub data: ChannelEventData, +} + +impl ChannelEvent { + pub fn sequence(&self) -> ResumePoint { + ResumePoint::singleton(&self.channel.id, self.sequence) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChannelEventData { + Message(MessageEvent), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageEvent { + pub id: message::Id, + pub sender: Login, + pub body: String, +} + +impl From<MessageEvent> for ChannelEventData { + fn from(message: MessageEvent) -> Self { + Self::Message(message) + } +} diff --git a/src/repo/channel.rs b/src/repo/channel.rs index 0186413..d223dab 100644 --- a/src/repo/channel.rs +++ b/src/repo/channel.rs @@ -16,7 +16,7 @@ impl<'c> Provider for Transaction<'c, Sqlite> { pub struct Channels<'t>(&'t mut SqliteConnection); -#[derive(Debug, Eq, PartialEq, serde::Serialize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Channel { pub id: Id, pub name: String, diff --git a/src/test/fixtures/message.rs b/src/test/fixtures/message.rs index 33feeae..bfca8cd 100644 --- a/src/test/fixtures/message.rs +++ b/src/test/fixtures/message.rs @@ -3,7 +3,7 @@ use faker_rand::lorem::Paragraphs; use crate::{ app::App, clock::RequestedAt, - events::repo::broadcast, + events::types, repo::{channel::Channel, login::Login}, }; @@ -12,7 +12,7 @@ pub async fn send( login: &Login, channel: &Channel, sent_at: &RequestedAt, -) -> broadcast::Message { +) -> types::ChannelEvent { let body = propose(); app.events() diff --git a/src/test/fixtures/mod.rs b/src/test/fixtures/mod.rs index a42dba5..450fbec 100644 --- a/src/test/fixtures/mod.rs +++ b/src/test/fixtures/mod.rs @@ -13,8 +13,6 @@ pub async fn scratch_app() -> App { .await .expect("setting up in-memory sqlite database"); App::from(pool) - .await - .expect("creating an app from a fresh, in-memory database") } pub fn now() -> RequestedAt { |
