diff options
Diffstat (limited to 'src/events')
| -rw-r--r-- | src/events/repo/broadcast.rs | 52 | ||||
| -rw-r--r-- | src/events/routes.rs | 66 | ||||
| -rw-r--r-- | src/events/routes/test.rs | 197 |
3 files changed, 269 insertions, 46 deletions
diff --git a/src/events/repo/broadcast.rs b/src/events/repo/broadcast.rs index bffe991..29dab55 100644 --- a/src/events/repo/broadcast.rs +++ b/src/events/repo/broadcast.rs @@ -24,6 +24,7 @@ pub struct Broadcast<'t>(&'t mut SqliteConnection); #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] pub struct Message { pub id: message::Id, + pub sequence: Sequence, pub sender: Login, pub body: String, pub sent_at: DateTime, @@ -37,27 +38,32 @@ impl<'c> Broadcast<'c> { body: &str, sent_at: &DateTime, ) -> Result<Message, sqlx::Error> { + let sequence = self.next_sequence_for(channel).await?; + let id = message::Id::generate(); let message = sqlx::query!( r#" insert into message - (id, sender, channel, body, sent_at) - values ($1, $2, $3, $4, $5) + (id, channel, sequence, sender, body, sent_at) + values ($1, $2, $3, $4, $5, $6) returning id as "id: message::Id", + sequence as "sequence: Sequence", sender as "sender: login::Id", body, sent_at as "sent_at: DateTime" "#, id, - sender.id, channel.id, + sequence, + sender.id, body, sent_at, ) .map(|row| Message { id: row.id, + sequence: row.sequence, sender: sender.clone(), body: row.body, sent_at: row.sent_at, @@ -68,6 +74,22 @@ impl<'c> Broadcast<'c> { Ok(message) } + async fn next_sequence_for(&mut self, channel: &Channel) -> Result<Sequence, sqlx::Error> { + let Sequence(current) = sqlx::query_scalar!( + r#" + -- `max` never returns null, but sqlx can't detect that + select max(sequence) as "sequence!: Sequence" + from message + where channel = $1 + "#, + channel.id, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(Sequence(current + 1)) + } + pub async fn expire(&mut self, expire_at: &DateTime) -> Result<(), sqlx::Error> { sqlx::query!( r#" @@ -85,12 +107,13 @@ impl<'c> Broadcast<'c> { pub async fn replay( &mut self, channel: &Channel, - resume_at: Option<&DateTime>, + resume_at: Option<Sequence>, ) -> Result<Vec<Message>, sqlx::Error> { let messages = sqlx::query!( r#" select message.id as "id: message::Id", + 
sequence as "sequence: Sequence", login.id as "sender_id: login::Id", login.name as sender_name, message.body, @@ -98,14 +121,15 @@ impl<'c> Broadcast<'c> { from message join login on message.sender = login.id where channel = $1 - and coalesce(sent_at > $2, true) - order by sent_at asc + and coalesce(sequence > $2, true) + order by sequence asc "#, channel.id, resume_at, ) .map(|row| Message { id: row.id, + sequence: row.sequence, sender: Login { id: row.sender_id, name: row.sender_name, @@ -119,3 +143,19 @@ impl<'c> Broadcast<'c> { Ok(messages) } } + +#[derive( + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + Clone, + Copy, + serde::Serialize, + serde::Deserialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); diff --git a/src/events/routes.rs b/src/events/routes.rs index a6bf5d9..7731680 100644 --- a/src/events/routes.rs +++ b/src/events/routes.rs @@ -1,3 +1,5 @@ +use std::collections::{BTreeMap, HashSet}; + use axum::{ extract::State, http::StatusCode, @@ -9,8 +11,10 @@ use axum::{ Router, }; use axum_extra::extract::Query; -use chrono::{self, format::SecondsFormat}; -use futures::stream::{self, Stream, StreamExt as _, TryStreamExt as _}; +use futures::{ + future, + stream::{self, Stream, StreamExt as _, TryStreamExt as _}, +}; use super::repo::broadcast; use crate::{ @@ -25,6 +29,15 @@ use crate::{ #[cfg(test)] mod test; +// For the purposes of event replay, an "event ID" is a vector of per-channel +// sequence numbers. Replay will start with messages whose sequence number in +// its channel is higher than the sequence in the event ID, or if the channel +// is not listed in the event ID, then at the beginning. +// +// Using a sorted map ensures that there is a canonical representation for +// each event ID. 
+type EventId = BTreeMap<channel::Id, broadcast::Sequence>; + pub fn router() -> Router<App> { Router::new().route("/api/events", get(events)) } @@ -32,22 +45,27 @@ pub fn router() -> Router<App> { #[derive(Clone, serde::Deserialize)] struct EventsQuery { #[serde(default, rename = "channel")] - channels: Vec<channel::Id>, + channels: HashSet<channel::Id>, } async fn events( State(app): State<App>, RequestedAt(now): RequestedAt, _: Login, // requires auth, but doesn't actually care who you are - last_event_id: Option<LastEventId>, + last_event_id: Option<LastEventId<EventId>>, Query(query): Query<EventsQuery>, -) -> Result<Events<impl Stream<Item = ChannelEvent> + std::fmt::Debug>, ErrorResponse> { - let resume_at = last_event_id.as_deref(); +) -> Result<Events<impl Stream<Item = ReplayableEvent> + std::fmt::Debug>, ErrorResponse> { + let resume_at = last_event_id + .map(LastEventId::into_inner) + .unwrap_or_default(); let streams = stream::iter(query.channels) .then(|channel| { let app = app.clone(); + let resume_at = resume_at.clone(); async move { + let resume_at = resume_at.get(&channel).copied(); + let events = app .channels() .events(&channel, &now, resume_at) @@ -62,7 +80,18 @@ async fn events( // impl From would take more code; this is used once. .map_err(ErrorResponse)?; - let stream = stream::select_all(streams); + // We resume counting from the provided last-event-id mapping, rather than + // starting from scratch, so that the events in a resumed stream contain + // the full vector of channel IDs for their event IDs right off the bat, + // even before any events are actually delivered. 
+ let stream = stream::select_all(streams).scan(resume_at, |sequences, event| { + let (channel, sequence) = event.event_id(); + sequences.insert(channel, sequence); + + let event = ReplayableEvent(sequences.clone(), event); + + future::ready(Some(event)) + }); Ok(Events(stream)) } @@ -72,7 +101,7 @@ struct Events<S>(S); impl<S> IntoResponse for Events<S> where - S: Stream<Item = ChannelEvent> + Send + 'static, + S: Stream<Item = ReplayableEvent> + Send + 'static, { fn into_response(self) -> Response { let Self(stream) = self; @@ -101,6 +130,9 @@ impl IntoResponse for ErrorResponse { } } +#[derive(Debug)] +struct ReplayableEvent(EventId, ChannelEvent); + #[derive(Debug, serde::Serialize)] struct ChannelEvent { channel: channel::Id, @@ -116,19 +148,21 @@ impl ChannelEvent { } } - fn event_id(&self) -> String { - self.message - .sent_at - .to_rfc3339_opts(SecondsFormat::AutoSi, /* use_z */ true) + fn event_id(&self) -> (channel::Id, broadcast::Sequence) { + (self.channel.clone(), self.message.sequence) } } -impl TryFrom<ChannelEvent> for sse::Event { +impl TryFrom<ReplayableEvent> for sse::Event { type Error = serde_json::Error; - fn try_from(value: ChannelEvent) -> Result<Self, Self::Error> { - let data = serde_json::to_string_pretty(&value)?; - let event = Self::default().id(value.event_id()).data(&data); + fn try_from(value: ReplayableEvent) -> Result<Self, Self::Error> { + let ReplayableEvent(id, data) = value; + + let id = serde_json::to_string(&id)?; + let data = serde_json::to_string_pretty(&data)?; + + let event = Self::default().id(id).data(data); Ok(event) } diff --git a/src/events/routes/test.rs b/src/events/routes/test.rs index df2d5f6..131c751 100644 --- a/src/events/routes/test.rs +++ b/src/events/routes/test.rs @@ -22,7 +22,9 @@ async fn no_subscriptions() { // Call the endpoint let subscribed_at = fixtures::now(); - let query = routes::EventsQuery { channels: vec![] }; + let query = routes::EventsQuery { + channels: [].into(), + }; let 
routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None, Query(query)) .await @@ -47,7 +49,7 @@ async fn includes_historical_message() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![channel.id.clone()], + channels: [channel.id.clone()].into(), }; let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None, Query(query)) @@ -56,7 +58,7 @@ async fn includes_historical_message() { // Verify the structure of the response. - let event = events + let routes::ReplayableEvent(_, event) = events .next() .immediately() .await @@ -78,7 +80,7 @@ async fn includes_live_message() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![channel.id.clone()], + channels: [channel.id.clone()].into(), }; let routes::Events(mut events) = routes::events( State(app.clone()), @@ -95,7 +97,7 @@ async fn includes_live_message() { let sender = fixtures::login::create(&app).await; let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; - let event = events + let routes::ReplayableEvent(_, event) = events .next() .immediately() .await @@ -121,7 +123,7 @@ async fn excludes_other_channels() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![subscribed.id.clone()], + channels: [subscribed.id.clone()].into(), }; let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None, Query(query)) @@ -130,7 +132,7 @@ async fn excludes_other_channels() { // Verify the semantics - let event = events + let routes::ReplayableEvent(_, event) = events .next() .immediately() .await @@ -186,9 +188,9 @@ async fn includes_multiple_channels() { .await; for (channel, message) in messages { - assert!(events - .iter() - 
.any(|event| { event.channel == channel.id && event.message == message })); + assert!(events.iter().any(|routes::ReplayableEvent(_, event)| { + event.channel == channel.id && event.message == message + })); } } @@ -204,7 +206,7 @@ async fn nonexitent_channel() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![channel.clone()], + channels: [channel.clone()].into(), }; let routes::ErrorResponse(error) = routes::events(State(app), subscribed_at, subscriber, None, Query(query)) @@ -239,7 +241,7 @@ async fn sequential_messages() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![channel.id.clone()], + channels: [channel.id.clone()].into(), }; let routes::Events(events) = routes::events(State(app), subscribed_at, subscriber, None, Query(query)) @@ -248,11 +250,13 @@ async fn sequential_messages() { // Verify the structure of the response. 
- let mut events = events.filter(|event| future::ready(messages.contains(&event.message))); + let mut events = events.filter(|routes::ReplayableEvent(_, event)| { + future::ready(messages.contains(&event.message)) + }); // Verify delivery in order for message in &messages { - let event = events + let routes::ReplayableEvent(_, event) = events .next() .immediately() .await @@ -283,7 +287,7 @@ async fn resumes_from() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![channel.id.clone()], + channels: [channel.id.clone()].into(), }; let resume_at = { @@ -298,19 +302,20 @@ async fn resumes_from() { .await .expect("subscribed to a valid channel"); - let event = events.next().immediately().await.expect("delivered events"); + let routes::ReplayableEvent(id, event) = + events.next().immediately().await.expect("delivered events"); assert_eq!(channel.id, event.channel); assert_eq!(initial_message, event.message); - event.event_id() + id }; // Resume after disconnect - let resumed_at = fixtures::now(); + let reconnect_at = fixtures::now(); let routes::Events(resumed) = routes::events( State(app), - resumed_at, + reconnect_at, subscriber, Some(resume_at.into()), Query(query), @@ -327,12 +332,156 @@ async fn resumes_from() { .await; for message in later_messages { - assert!(events - .iter() - .any(|event| event.channel == channel.id && event.message == message)); + assert!(events.iter().any( + |routes::ReplayableEvent(_, event)| event.channel == channel.id + && event.message == message + )); } } +// This test verifies a real bug I hit developing the vector-of-sequences +// approach to resuming events. A small omission caused the event IDs in a +// resumed stream to _omit_ channels that were in the original stream until +// those channels also appeared in the resumed stream. 
+// +// Clients would see something like +// * In the original stream, Cfoo=5,Cbar=8 +// * In the resumed stream, Cfoo=6 (no Cbar sequence number) +// +// Disconnecting and reconnecting a second time, using event IDs from that +// initial period of the first resume attempt, would then cause the second +// resume attempt to restart all other channels from the beginning, and not +// from where the first disconnection happened. +// +// This is a real and valid behaviour for clients! +#[tokio::test] +async fn serial_resume() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + let channel_a = fixtures::channel::create(&app).await; + let channel_b = fixtures::channel::create(&app).await; + + // Call the endpoint + + let subscriber = fixtures::login::create(&app).await; + let query = routes::EventsQuery { + channels: [channel_a.id.clone(), channel_b.id.clone()].into(), + }; + + let resume_at = { + let initial_messages = [ + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + ]; + + // First subscription + let subscribed_at = fixtures::now(); + let routes::Events(events) = routes::events( + State(app.clone()), + subscribed_at, + subscriber.clone(), + None, + Query(query.clone()), + ) + .await + .expect("subscribed to a valid channel"); + + let events = events + .take(initial_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in initial_messages { + assert!(events + .iter() + .any(|routes::ReplayableEvent(_, event)| event.message == message)); + } + + let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty"); + + id.to_owned() + }; + + // Resume after disconnect + let resume_at = { + let resume_messages = [ + // Note that channel_b does not appear here. 
The buggy behaviour + // would be masked if channel_b happened to send a new message + // into the resumed event stream. + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + ]; + + // Second subscription + let resubscribed_at = fixtures::now(); + let routes::Events(events) = routes::events( + State(app.clone()), + resubscribed_at, + subscriber.clone(), + Some(resume_at.into()), + Query(query.clone()), + ) + .await + .expect("subscribed to a valid channel"); + + let events = events + .take(resume_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in resume_messages { + assert!(events + .iter() + .any(|routes::ReplayableEvent(_, event)| event.message == message)); + } + + let routes::ReplayableEvent(id, _) = events.last().expect("this vec is non-empty"); + + id.to_owned() + }; + + // Resume after disconnect a second time + { + // At this point, we can send on either channel and demonstrate the + // problem. The resume point should be before both of these messages, but + // after _all_ prior messages. + let final_messages = [ + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + ]; + + // Third subscription + let resubscribed_at = fixtures::now(); + let routes::Events(events) = routes::events( + State(app.clone()), + resubscribed_at, + subscriber.clone(), + Some(resume_at.into()), + Query(query.clone()), + ) + .await + .expect("subscribed to a valid channel"); + + let events = events + .take(final_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + // This set of messages, in particular, _should not_ include any prior + // messages from `initial_messages` or `resume_messages`. 
+ for message in final_messages { + assert!(events + .iter() + .any(|routes::ReplayableEvent(_, event)| event.message == message)); + } + }; +} + #[tokio::test] async fn removes_expired_messages() { // Set up the environment @@ -348,7 +497,7 @@ async fn removes_expired_messages() { let subscriber = fixtures::login::create(&app).await; let subscribed_at = fixtures::now(); let query = routes::EventsQuery { - channels: vec![channel.id.clone()], + channels: [channel.id.clone()].into_iter().collect(), }; let routes::Events(mut events) = routes::events(State(app), subscribed_at, subscriber, None, Query(query)) @@ -357,7 +506,7 @@ async fn removes_expired_messages() { // Verify the semantics - let event = events + let routes::ReplayableEvent(_, event) = events .next() .immediately() .await |
