diff options
| author | Owen Jacobson <owen@grimoire.ca> | 2024-10-02 00:41:25 -0400 |
|---|---|---|
| committer | Owen Jacobson <owen@grimoire.ca> | 2024-10-02 00:41:38 -0400 |
| commit | 357116366c1307bedaac6a3dfe9c5ed8e0e0c210 (patch) | |
| tree | d701378187d8b0f99d524991925e8348c6cab0d6 /src/event | |
| parent | f878f0b5eaa44e8ee8d67cbfd706926ff2119113 (diff) | |
First pass on reorganizing the backend.
This is primarily renames and repackagings.
Diffstat (limited to 'src/event')
| -rw-r--r-- | src/event/app.rs | 137 | ||||
| -rw-r--r-- | src/event/broadcaster.rs | 3 | ||||
| -rw-r--r-- | src/event/extract.rs | 85 | ||||
| -rw-r--r-- | src/event/mod.rs | 9 | ||||
| -rw-r--r-- | src/event/repo/message.rs | 188 | ||||
| -rw-r--r-- | src/event/repo/mod.rs | 1 | ||||
| -rw-r--r-- | src/event/routes.rs | 93 | ||||
| -rw-r--r-- | src/event/routes/test.rs | 439 | ||||
| -rw-r--r-- | src/event/sequence.rs | 24 | ||||
| -rw-r--r-- | src/event/types.rs | 97 |
10 files changed, 1076 insertions, 0 deletions
diff --git a/src/event/app.rs b/src/event/app.rs new file mode 100644 index 0000000..b5f2ecc --- /dev/null +++ b/src/event/app.rs @@ -0,0 +1,137 @@ +use chrono::TimeDelta; +use futures::{ + future, + stream::{self, StreamExt as _}, + Stream, +}; +use sqlx::sqlite::SqlitePool; + +use super::{ + broadcaster::Broadcaster, + repo::message::Provider as _, + types::{self, ChannelEvent}, +}; +use crate::{ + channel, + clock::DateTime, + event::Sequence, + login::Login, + repo::{channel::Provider as _, error::NotFound as _, sequence::Provider as _}, +}; + +pub struct Events<'a> { + db: &'a SqlitePool, + events: &'a Broadcaster, +} + +impl<'a> Events<'a> { + pub const fn new(db: &'a SqlitePool, events: &'a Broadcaster) -> Self { + Self { db, events } + } + + pub async fn send( + &self, + login: &Login, + channel: &channel::Id, + body: &str, + sent_at: &DateTime, + ) -> Result<types::ChannelEvent, EventsError> { + let mut tx = self.db.begin().await?; + let channel = tx + .channels() + .by_id(channel) + .await + .not_found(|| EventsError::ChannelNotFound(channel.clone()))?; + let sent_sequence = tx.sequence().next().await?; + let event = tx + .message_events() + .create(login, &channel, sent_at, sent_sequence, body) + .await?; + tx.commit().await?; + + self.events.broadcast(&event); + Ok(event) + } + + pub async fn expire(&self, relative_to: &DateTime) -> Result<(), sqlx::Error> { + // Somewhat arbitrarily, expire after 90 days. + let expire_at = relative_to.to_owned() - TimeDelta::days(90); + + let mut tx = self.db.begin().await?; + let expired = tx.message_events().expired(&expire_at).await?; + + let mut events = Vec::with_capacity(expired.len()); + for (channel, message) in expired { + let deleted_sequence = tx.sequence().next().await?; + let event = tx + .message_events() + .delete(&channel, &message, relative_to, deleted_sequence) + .await?; + events.push(event); + } + + tx.commit().await?; + + for event in events { + self.events.broadcast(&event); + } + + Ok(()) + } + + pub async fn subscribe( + &self, + resume_at: Option<Sequence>, + ) -> Result<impl Stream<Item = types::ChannelEvent> + std::fmt::Debug, sqlx::Error> { + // Subscribe before retrieving, to catch messages broadcast while we're + // querying the DB. We'll prune out duplicates later. + let live_messages = self.events.subscribe(); + + let mut tx = self.db.begin().await?; + let channels = tx.channels().replay(resume_at).await?; + + let channel_events = channels + .into_iter() + .map(ChannelEvent::created) + .filter(move |event| resume_at.map_or(true, |resume_at| event.sequence > resume_at)); + + let message_events = tx.message_events().replay(resume_at).await?; + + let mut replay_events = channel_events + .into_iter() + .chain(message_events.into_iter()) + .collect::<Vec<_>>(); + replay_events.sort_by_key(|event| event.sequence); + let resume_live_at = replay_events.last().map(|event| event.sequence); + + let replay = stream::iter(replay_events); + + // no skip_expired or resume transforms for stored_messages, as it's + // constructed not to contain messages meeting either criterion. + // + // * skip_expired is redundant with the `tx.broadcasts().expire(…)` call; + // * resume is redundant with the resume_at argument to + // `tx.broadcasts().replay(…)`. + let live_messages = live_messages + // Filtering on the broadcast resume point filters out messages + // before resume_at, and filters out messages duplicated from + // stored_messages. + .filter(Self::resume(resume_live_at)); + + Ok(replay.chain(live_messages)) + } + + fn resume( + resume_at: Option<Sequence>, + ) -> impl for<'m> FnMut(&'m types::ChannelEvent) -> future::Ready<bool> { + move |event| future::ready(resume_at < Some(event.sequence)) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum EventsError { + #[error("channel {0} not found")] + ChannelNotFound(channel::Id), + #[error(transparent)] + DatabaseError(#[from] sqlx::Error), +} diff --git a/src/event/broadcaster.rs b/src/event/broadcaster.rs new file mode 100644 index 0000000..92f631f --- /dev/null +++ b/src/event/broadcaster.rs @@ -0,0 +1,3 @@ +use crate::{broadcast, event::types}; + +pub type Broadcaster = broadcast::Broadcaster<types::ChannelEvent>; diff --git a/src/event/extract.rs b/src/event/extract.rs new file mode 100644 index 0000000..e3021e2 --- /dev/null +++ b/src/event/extract.rs @@ -0,0 +1,85 @@ +use std::ops::Deref; + +use axum::{ + extract::FromRequestParts, + http::{request::Parts, HeaderName, HeaderValue}, +}; +use axum_extra::typed_header::TypedHeader; +use serde::{de::DeserializeOwned, Serialize}; + +// A typed header. When used as a bare extractor, reads from the +// `Last-Event-Id` HTTP header. +pub struct LastEventId<T>(pub T); + +static LAST_EVENT_ID: HeaderName = HeaderName::from_static("last-event-id"); + +impl<T> headers::Header for LastEventId<T> +where + T: Serialize + DeserializeOwned, +{ + fn name() -> &'static HeaderName { + &LAST_EVENT_ID + } + + fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error> + where + I: Iterator<Item = &'i HeaderValue>, + { + let value = values.next().ok_or_else(headers::Error::invalid)?; + let value = value.to_str().map_err(|_| headers::Error::invalid())?; + let value = serde_json::from_str(value).map_err(|_| headers::Error::invalid())?; + Ok(Self(value)) + } + + fn encode<E>(&self, values: &mut E) + where + E: Extend<HeaderValue>, + { + let Self(value) = self; + // Must panic or suppress; the trait provides no other options. + let value = serde_json::to_string(value).expect("value can be encoded as JSON"); + let value = HeaderValue::from_str(&value).expect("LastEventId is a valid header value"); + + values.extend(std::iter::once(value)); + } +} + +#[async_trait::async_trait] +impl<S, T> FromRequestParts<S> for LastEventId<T> +where + S: Send + Sync, + T: Serialize + DeserializeOwned, +{ + type Rejection = <TypedHeader<Self> as FromRequestParts<S>>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { + // This is purely for ergonomics: it allows `RequestedAt` to be extracted + // without having to wrap it in `Extension<>`. Callers _can_ still do that, + // but they aren't forced to. + let TypedHeader(requested_at) = TypedHeader::from_request_parts(parts, state).await?; + + Ok(requested_at) + } +} + +impl<T> Deref for LastEventId<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + let Self(header) = self; + header + } +} + +impl<T> From<T> for LastEventId<T> { + fn from(value: T) -> Self { + Self(value) + } +} + +impl<T> LastEventId<T> { + pub fn into_inner(self) -> T { + let Self(value) = self; + value + } +} diff --git a/src/event/mod.rs b/src/event/mod.rs new file mode 100644 index 0000000..7ad3f9c --- /dev/null +++ b/src/event/mod.rs @@ -0,0 +1,9 @@ +pub mod app; +pub mod broadcaster; +mod extract; +pub mod repo; +mod routes; +mod sequence; +pub mod types; + +pub use self::{routes::router, sequence::Sequence}; diff --git a/src/event/repo/message.rs b/src/event/repo/message.rs new file mode 100644 index 0000000..f051fec --- /dev/null +++ b/src/event/repo/message.rs @@ -0,0 +1,188 @@ +use sqlx::{sqlite::Sqlite, SqliteConnection, Transaction}; + +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::{types, Sequence}, + login::{self, Login}, + message::{self, Message}, +}; + +pub trait Provider { + fn message_events(&mut self) -> Events; +} + +impl<'c> Provider for Transaction<'c, Sqlite> { + fn message_events(&mut self) -> Events { + Events(self) + } +} + +pub struct Events<'t>(&'t mut SqliteConnection); + +impl<'c> Events<'c> { + pub async fn create( + &mut self, + sender: &Login, + channel: &Channel, + sent_at: &DateTime, + sent_sequence: Sequence, + body: &str, + ) -> Result<types::ChannelEvent, sqlx::Error> { + let id = message::Id::generate(); + + let message = sqlx::query!( + r#" + insert into message + (id, channel, sender, sent_at, sent_sequence, body) + values ($1, $2, $3, $4, $5, $6) + returning + id as "id: message::Id", + sender as "sender: login::Id", + sent_at as "sent_at: DateTime", + sent_sequence as "sent_sequence: Sequence", + body + "#, + id, + channel.id, + sender.id, + sent_at, + sent_sequence, + body, + ) + .map(|row| types::ChannelEvent { + sequence: row.sent_sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: channel.clone(), + sender: sender.clone(), + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_one(&mut *self.0) + .await?; + + Ok(message) + } + + pub async fn delete( + &mut self, + channel: &Channel, + message: &message::Id, + deleted_at: &DateTime, + deleted_sequence: Sequence, + ) -> Result<types::ChannelEvent, sqlx::Error> { + sqlx::query_scalar!( + r#" + delete from message + where id = $1 + returning 1 as "row: i64" + "#, + message, + ) + .fetch_one(&mut *self.0) + .await?; + + Ok(types::ChannelEvent { + sequence: deleted_sequence, + at: *deleted_at, + data: types::MessageDeletedEvent { + channel: channel.clone(), + message: message.clone(), + } + .into(), + }) + } + + pub async fn expired( + &mut self, + expire_at: &DateTime, + ) -> Result<Vec<(Channel, message::Id)>, sqlx::Error> { + let messages = sqlx::query!( + r#" + select + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + message.id as "message: message::Id" + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where sent_at < $1 + "#, + expire_at, + ) + .map(|row| { + ( + Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, + row.message, + ) + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(messages) + } + + pub async fn replay( + &mut self, + resume_at: Option<Sequence>, + ) -> Result<Vec<types::ChannelEvent>, sqlx::Error> { + let events = sqlx::query!( + r#" + select + message.id as "id: message::Id", + channel.id as "channel_id: channel::Id", + channel.name as "channel_name", + channel.created_at as "channel_created_at: DateTime", + channel.created_sequence as "channel_created_sequence: Sequence", + sender.id as "sender_id: login::Id", + sender.name as sender_name, + message.sent_at as "sent_at: DateTime", + message.sent_sequence as "sent_sequence: Sequence", + message.body + from message + join channel on message.channel = channel.id + join login as sender on message.sender = sender.id + where coalesce(message.sent_sequence > $1, true) + order by sent_sequence asc + "#, + resume_at, + ) + .map(|row| types::ChannelEvent { + sequence: row.sent_sequence, + at: row.sent_at, + data: types::MessageEvent { + channel: Channel { + id: row.channel_id, + name: row.channel_name, + created_at: row.channel_created_at, + created_sequence: row.channel_created_sequence, + }, + sender: Login { + id: row.sender_id, + name: row.sender_name, + }, + message: Message { + id: row.id, + body: row.body, + }, + } + .into(), + }) + .fetch_all(&mut *self.0) + .await?; + + Ok(events) + } +} diff --git a/src/event/repo/mod.rs b/src/event/repo/mod.rs new file mode 100644 index 0000000..e216a50 --- /dev/null +++ b/src/event/repo/mod.rs @@ -0,0 +1 @@ +pub mod message; diff --git a/src/event/routes.rs b/src/event/routes.rs new file mode 100644 index 0000000..77761ca --- /dev/null +++ b/src/event/routes.rs @@ -0,0 +1,93 @@ +use axum::{ + extract::State, + response::{ + sse::{self, Sse}, + IntoResponse, Response, + }, + routing::get, + Router, +}; +use axum_extra::extract::Query; +use futures::stream::{Stream, StreamExt as _}; + +use super::{extract::LastEventId, types}; +use crate::{ + app::App, + error::{Internal, Unauthorized}, + event::Sequence, + login::app::ValidateError, + token::extract::Identity, +}; + +#[cfg(test)] +mod test; + +pub fn router() -> Router<App> { + Router::new().route("/api/events", get(events)) +} + +#[derive(Default, serde::Deserialize)] +struct EventsQuery { + resume_point: Option<Sequence>, +} + +async fn events( + State(app): State<App>, + identity: Identity, + last_event_id: Option<LastEventId<Sequence>>, + Query(query): Query<EventsQuery>, +) -> Result<Events<impl Stream<Item = types::ChannelEvent> + std::fmt::Debug>, EventsError> { + let resume_at = last_event_id + .map(LastEventId::into_inner) + .or(query.resume_point); + + let stream = app.events().subscribe(resume_at).await?; + let stream = app.logins().limit_stream(identity.token, stream).await?; + + Ok(Events(stream)) +} + +#[derive(Debug)] +struct Events<S>(S); + +impl<S> IntoResponse for Events<S> +where + S: Stream<Item = types::ChannelEvent> + Send + 'static, +{ + fn into_response(self) -> Response { + let Self(stream) = self; + let stream = stream.map(sse::Event::try_from); + Sse::new(stream) + .keep_alive(sse::KeepAlive::default()) + .into_response() + } +} + +impl TryFrom<types::ChannelEvent> for sse::Event { + type Error = serde_json::Error; + + fn try_from(event: types::ChannelEvent) -> Result<Self, Self::Error> { + let id = serde_json::to_string(&event.sequence)?; + let data = serde_json::to_string_pretty(&event)?; + + let event = Self::default().id(id).data(data); + + Ok(event) + } +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub enum EventsError { + DatabaseError(#[from] sqlx::Error), + ValidateError(#[from] ValidateError), +} + +impl IntoResponse for EventsError { + fn into_response(self) -> Response { + match self { + Self::ValidateError(ValidateError::InvalidToken) => Unauthorized.into_response(), + other => Internal::from(other).into_response(), + } + } +} diff --git a/src/event/routes/test.rs b/src/event/routes/test.rs new file mode 100644 index 0000000..9a3b12a --- /dev/null +++ b/src/event/routes/test.rs @@ -0,0 +1,439 @@ +use axum::extract::State; +use axum_extra::extract::Query; +use futures::{ + future, + stream::{self, StreamExt as _}, +}; + +use crate::{ + event::routes, + test::fixtures::{self, future::Immediately as _}, +}; + +#[tokio::test] +async fn includes_historical_message() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered stored message"); + + assert_eq!(message, event); +} + +#[tokio::test] +async fn includes_live_message() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the semantics + + let sender = fixtures::login::create(&app).await; + let message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered live message"); + + assert_eq!(message, event); +} + +#[tokio::test] +async fn includes_multiple_channels() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + + let channels = [ + fixtures::channel::create(&app, &fixtures::now()).await, + fixtures::channel::create(&app, &fixtures::now()).await, + ]; + + let messages = stream::iter(channels) + .then(|channel| { + let app = app.clone(); + let sender = sender.clone(); + let channel = channel.clone(); + async move { fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await } + }) + .collect::<Vec<_>>() + .await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let events = events + .filter(fixtures::filter::messages()) + .take(messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &messages { + assert!(events.iter().any(|event| { event == message })); + } +} + +#[tokio::test] +async fn sequential_messages() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + let messages = vec![ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + let routes::Events(events) = routes::events(State(app), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let mut events = events.filter(|event| future::ready(messages.contains(event))); + + // Verify delivery in order + for message in &messages { + let event = events + .next() + .immediately() + .await + .expect("undelivered messages remaining"); + + assert_eq!(message, &event); + } +} + +#[tokio::test] +async fn resumes_from() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + let initial_message = fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await; + + let later_messages = vec![ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + + let resume_at = { + // First subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let event = events + .filter(fixtures::filter::messages()) + .next() + .immediately() + .await + .expect("delivered events"); + + assert_eq!(initial_message, event); + + event.sequence + }; + + // Resume after disconnect + let routes::Events(resumed) = routes::events( + State(app), + subscriber, + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); + + // Verify the structure of the response. + + let events = resumed + .take(later_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &later_messages { + assert!(events.iter().any(|event| event == message)); + } +} + +// This test verifies a real bug I hit developing the vector-of-sequences +// approach to resuming events. A small omission caused the event IDs in a +// resumed stream to _omit_ channels that were in the original stream until +// those channels also appeared in the resumed stream. +// +// Clients would see something like +// * In the original stream, Cfoo=5,Cbar=8 +// * In the resumed stream, Cfoo=6 (no Cbar sequence number) +// +// Disconnecting and reconnecting a second time, using event IDs from that +// initial period of the first resume attempt, would then cause the second +// resume attempt to restart all other channels from the beginning, and not +// from where the first disconnection happened. +// +// This is a real and valid behaviour for clients! +#[tokio::test] +async fn serial_resume() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let sender = fixtures::login::create(&app).await; + let channel_a = fixtures::channel::create(&app, &fixtures::now()).await; + let channel_b = fixtures::channel::create(&app, &fixtures::now()).await; + + // Call the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = fixtures::identity::identity(&app, &subscriber_creds, &fixtures::now()).await; + + let resume_at = { + let initial_messages = [ + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + ]; + + // First subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let events = events + .filter(fixtures::filter::messages()) + .take(initial_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &initial_messages { + assert!(events.iter().any(|event| event == message)); + } + + let event = events.last().expect("this vec is non-empty"); + + event.sequence + }; + + // Resume after disconnect + let resume_at = { + let resume_messages = [ + // Note that channel_b does not appear here. The buggy behaviour + // would be masked if channel_b happened to send a new message + // into the resumed event stream. + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + ]; + + // Second subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let events = events + .filter(fixtures::filter::messages()) + .take(resume_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + for message in &resume_messages { + assert!(events.iter().any(|event| event == message)); + } + + let event = events.last().expect("this vec is non-empty"); + + event.sequence + }; + + // Resume after disconnect a second time + { + // At this point, we can send on either channel and demonstrate the + // problem. The resume point should before both of these messages, but + // after _all_ prior messages. + let final_messages = [ + fixtures::message::send(&app, &sender, &channel_a, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel_b, &fixtures::now()).await, + ]; + + // Third subscription + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + Some(resume_at.into()), + Query::default(), + ) + .await + .expect("subscribe never fails"); + + let events = events + .filter(fixtures::filter::messages()) + .take(final_messages.len()) + .collect::<Vec<_>>() + .immediately() + .await; + + // This set of messages, in particular, _should not_ include any prior + // messages from `initial_messages` or `resume_messages`. + for message in &final_messages { + assert!(events.iter().any(|event| event == message)); + } + }; +} + +#[tokio::test] +async fn terminates_on_token_expiry() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + // Subscribe via the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber = + fixtures::identity::identity(&app, &subscriber_creds, &fixtures::ancient()).await; + + let routes::Events(events) = + routes::events(State(app.clone()), subscriber, None, Query::default()) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + app.logins() + .expire(&fixtures::now()) + .await + .expect("expiring tokens succeeds"); + + // These should not be delivered. + let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|event| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} + +#[tokio::test] +async fn terminates_on_logout() { + // Set up the environment + + let app = fixtures::scratch_app().await; + let channel = fixtures::channel::create(&app, &fixtures::now()).await; + let sender = fixtures::login::create(&app).await; + + // Subscribe via the endpoint + + let subscriber_creds = fixtures::login::create_with_password(&app).await; + let subscriber_token = + fixtures::identity::logged_in(&app, &subscriber_creds, &fixtures::now()).await; + let subscriber = + fixtures::identity::from_token(&app, &subscriber_token, &fixtures::now()).await; + + let routes::Events(events) = routes::events( + State(app.clone()), + subscriber.clone(), + None, + Query::default(), + ) + .await + .expect("subscribe never fails"); + + // Verify the resulting stream's behaviour + + app.logins() + .logout(&subscriber.token) + .await + .expect("expiring tokens succeeds"); + + // These should not be delivered. + let messages = [ + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + fixtures::message::send(&app, &sender, &channel, &fixtures::now()).await, + ]; + + assert!(events + .filter(|event| future::ready(messages.contains(event))) + .next() + .immediately() + .await + .is_none()); +} diff --git a/src/event/sequence.rs b/src/event/sequence.rs new file mode 100644 index 0000000..9ebddd7 --- /dev/null +++ b/src/event/sequence.rs @@ -0,0 +1,24 @@ +use std::fmt; + +#[derive( + Clone, + Copy, + Debug, + Eq, + Ord, + PartialEq, + PartialOrd, + serde::Deserialize, + serde::Serialize, + sqlx::Type, +)] +#[serde(transparent)] +#[sqlx(transparent)] +pub struct Sequence(i64); + +impl fmt::Display for Sequence { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(value) = self; + value.fmt(f) + } +} diff --git a/src/event/types.rs b/src/event/types.rs new file mode 100644 index 0000000..cd7dea6 --- /dev/null +++ b/src/event/types.rs @@ -0,0 +1,97 @@ +use crate::{ + channel::{self, Channel}, + clock::DateTime, + event::Sequence, + login::Login, + message::{self, Message}, +}; + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct ChannelEvent { + #[serde(skip)] + pub sequence: Sequence, + pub at: DateTime, + #[serde(flatten)] + pub data: ChannelEventData, +} + +impl ChannelEvent { + pub fn created(channel: Channel) -> Self { + Self { + at: channel.created_at, + sequence: channel.created_sequence, + data: CreatedEvent { channel }.into(), + } + } + + pub fn channel_id(&self) -> &channel::Id { + match &self.data { + ChannelEventData::Created(event) => &event.channel.id, + ChannelEventData::Message(event) => &event.channel.id, + ChannelEventData::MessageDeleted(event) => &event.channel.id, + ChannelEventData::Deleted(event) => &event.channel, + } + } +} + +impl<'c> From<&'c ChannelEvent> for Sequence { + fn from(event: &'c ChannelEvent) -> Self { + event.sequence + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ChannelEventData { + Created(CreatedEvent), + Message(MessageEvent), + MessageDeleted(MessageDeletedEvent), + Deleted(DeletedEvent), +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct CreatedEvent { + pub channel: Channel, +} + +impl From<CreatedEvent> for ChannelEventData { + fn from(event: CreatedEvent) -> Self { + Self::Created(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageEvent { + pub channel: Channel, + pub sender: Login, + pub message: Message, +} + +impl From<MessageEvent> for ChannelEventData { + fn from(event: MessageEvent) -> Self { + Self::Message(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct MessageDeletedEvent { + pub channel: Channel, + pub message: message::Id, +} + +impl From<MessageDeletedEvent> for ChannelEventData { + fn from(event: MessageDeletedEvent) -> Self { + Self::MessageDeleted(event) + } +} + +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize)] +pub struct DeletedEvent { + pub channel: channel::Id, +} + +impl From<DeletedEvent> for ChannelEventData { + fn from(event: DeletedEvent) -> Self { + Self::Deleted(event) + } +} |
